Skip to content

Commit

Permalink
Fixes to utility functions and apps to support multi-filter queries
Browse files Browse the repository at this point in the history
  • Loading branch information
gopal-msr committed Jun 17, 2024
1 parent 39a2005 commit 5473656
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 22 deletions.
68 changes: 55 additions & 13 deletions apps/search_disk_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "common_includes.h"
#include <boost/program_options.hpp>

#include "utils.h"
#include "index.h"
#include "disk_utils.h"
#include "math_utils.h"
Expand Down Expand Up @@ -47,6 +48,44 @@ void print_stats(std::string category, std::vector<float> percentiles, std::vect
diskann::cout << std::endl;
}

// Splits the filter expression of a single query on FILTER_OR_SEPARATOR and
// appends the converted id of every resulting label string to
// label_ids_for_query.
//   filters_for_query   - filter expression belonging to one query
//   pFlashIndex         - index whose get_converted_label() turns a label
//                         string into a LabelT
//   label_ids_for_query - output; ids are appended, existing entries are kept
// NOTE(review): per the review remark embedded below, get_converted_label()
// may throw when a label is not present in the label map; nothing in this
// function catches it, so it would propagate to the caller - confirm that is
// the intended handling.
template<typename T, typename LabelT>
void parse_labels_of_query(const std::string &filters_for_query,
std::unique_ptr<diskann::PQFlashIndex<T, LabelT>> &pFlashIndex,
std::vector<LabelT> &label_ids_for_query)
{
std::vector<std::string> label_strs_for_query;
diskann::split_string(filters_for_query, FILTER_OR_SEPARATOR, label_strs_for_query);
for (auto &label_str_for_query : label_strs_for_query)
{
label_ids_for_query.push_back(pFlashIndex->get_converted_label(label_str_for_query));

This comment has been minimized.

Copy link
@MS-Renan

MS-Renan Jun 20, 2024

get_converted_label can raise exception if label does not exist in map

}
}

template<typename T, typename LabelT>
void populate_label_ids(const std::vector<std::string> &filters_of_queries,
std::unique_ptr<diskann::PQFlashIndex<T, LabelT>> &pFlashIndex,
std::vector<std::vector<LabelT>> &label_ids_of_queries, bool apply_one_to_all, uint32_t query_count)
{
if (apply_one_to_all)
{
std::vector<LabelT> label_ids_of_query;
parse_labels_of_query(filters_of_queries[0], pFlashIndex, label_ids_of_query);
for (uint32_t i = 0; i < query_count; i++)
{
label_ids_of_queries.push_back(label_ids_of_query);
}
}
else
{
for (auto &filters_of_query : filters_of_queries)
{
std::vector<LabelT> label_ids_of_query;
parse_labels_of_query(filters_of_query, pFlashIndex, label_ids_of_query);
label_ids_of_queries.push_back(label_ids_of_query);
}
}
}

template <typename T, typename LabelT = uint32_t>
int search_disk_index(diskann::Metric &metric, const std::string &index_path_prefix,
const std::string &result_output_prefix, const std::string &query_file, std::string &gt_file,
Expand Down Expand Up @@ -173,6 +212,14 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre
diskann::cout << "..done" << std::endl;
}

std::vector<std::vector<LabelT>> per_query_label_ids;
if (filtered_search)
{
populate_label_ids(query_filters, _pFlashIndex, per_query_label_ids, (query_filters.size() == 1), query_num );
}



diskann::cout.setf(std::ios_base::fixed, std::ios_base::floatfield);
diskann::cout.precision(2);

Expand Down Expand Up @@ -236,19 +283,10 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre
}
else
{
LabelT label_for_search;
if (query_filters.size() == 1)
{ // one label for all queries
label_for_search = _pFlashIndex->get_converted_label(query_filters[0]);
}
else
{ // one label for each query
label_for_search = _pFlashIndex->get_converted_label(query_filters[i]);
}
_pFlashIndex->cached_beam_search(
query + (i * query_aligned_dim), recall_at, L, query_result_ids_64.data() + (i * recall_at),
query_result_dists[test_id].data() + (i * recall_at), optimized_beamwidth, true, label_for_search,
use_reorder_data, stats + i);
query_result_dists[test_id].data() + (i * recall_at), optimized_beamwidth, true, per_query_label_ids[i],
search_io_limit, use_reorder_data, stats + i);
}
}
auto e = std::chrono::high_resolution_clock::now();
Expand All @@ -270,6 +308,9 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre
auto mean_cpuus = diskann::get_mean_stats<float>(stats, query_num,
[](const diskann::QueryStats &stats) { return stats.cpu_us; });

auto mean_hops = diskann::get_mean_stats<uint32_t>(
stats, query_num, [](const diskann::QueryStats &stats) { return stats.n_hops; });

double recall = 0;
if (calc_recall_flag)
{
Expand All @@ -283,10 +324,12 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre
<< std::setw(16) << mean_cpuus;
if (calc_recall_flag)
{
diskann::cout << std::setw(16) << recall << std::endl;
diskann::cout << std::setw(16) << recall << std::endl ;
}
else
{
diskann::cout << std::endl;
}
delete[] stats;
}

Expand Down Expand Up @@ -443,7 +486,6 @@ int main(int argc, char **argv)
{
query_filters = read_file_to_vector_of_strings(query_filters_file);
}

try
{
if (!query_filters.empty() && label_type == "ushort")
Expand Down
2 changes: 1 addition & 1 deletion include/abstract_data_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ template <typename data_t> class AbstractDataStore

// Return number of points returned
virtual location_t load(const std::string &filename, size_t offset) = 0;
virtual location_t load(AlignedFileReader &reader, size_t offset) = 0;
//virtual location_t load(AlignedFileReader &reader, size_t offset) = 0;

// Why does store take num_pts? Since store only has capacity, but we allow
// resizing we can end up in a situation where the store has spare capacity.
Expand Down
4 changes: 2 additions & 2 deletions include/in_mem_data_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ template <typename data_t> class InMemDataStore : public AbstractDataStore<data_
virtual ~InMemDataStore();

virtual location_t load(const std::string &filename, size_t offset = 0) override;
virtual location_t load(AlignedFileReader &reader, size_t offset = 0) override;
//virtual location_t load(AlignedFileReader &reader, size_t offset = 0) override;
virtual size_t save(const std::string &filename, const location_t num_pts) override;
virtual size_t save(std::ofstream &writer, const location_t num_pts, size_t offset) override;

Expand Down Expand Up @@ -63,7 +63,7 @@ template <typename data_t> class InMemDataStore : public AbstractDataStore<data_

virtual location_t load_impl(const std::string &filename, size_t offset);
#ifdef EXEC_ENV_OLS
virtual location_t load_impl(AlignedFileReader &reader, size_t offset);
//virtual location_t load_impl(AlignedFileReader &reader, size_t offset);
#endif

private:
Expand Down
5 changes: 4 additions & 1 deletion include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ typedef int FileHandle;

#ifdef EXEC_ENV_OLS
#include "content_buf.h"
#include "memory_mapped_files.h"
#include "memory_mapper.h"

This comment has been minimized.

Copy link
@MS-Renan

MS-Renan Jun 20, 2024

How come this was changed?

#endif

// taken from
Expand All @@ -56,6 +56,7 @@ typedef int FileHandle;

#define PBSTR "||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||"
#define PBWIDTH 60
#define FILTER_OR_SEPARATOR "|"

inline bool file_exists_impl(const std::string &name, bool dirCheck = false)
{
Expand Down Expand Up @@ -681,6 +682,8 @@ DISKANN_DLLEXPORT double calculate_recall(unsigned num_queries, unsigned *gold_s
DISKANN_DLLEXPORT double calculate_range_search_recall(unsigned num_queries,
std::vector<std::vector<uint32_t>> &groundtruth,
std::vector<std::vector<uint32_t>> &our_results);
DISKANN_DLLEXPORT void split_string(const std::string &string_to_split, const std::string &delimiter,
std::vector<std::string> &pieces);

template <typename T>
inline void load_bin(const std::string &bin_file, std::unique_ptr<T[]> &data, size_t &npts, size_t &dim,
Expand Down
8 changes: 4 additions & 4 deletions src/in_mem_data_store.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ template <typename data_t> location_t InMemDataStore<data_t>::load(const std::st
return load_impl(filename, offset);
}

template <typename data_t> location_t InMemDataStore<data_t>::load(AlignedFileReader &reader, size_t offset)
{
return load_impl(reader, offset);
}
//template <typename data_t> location_t InMemDataStore<data_t>::load(AlignedFileReader &reader, size_t offset)
//{
// return load_impl(reader, offset);
//}

#ifdef EXEC_ENV_OLS
template <typename data_t> location_t InMemDataStore<data_t>::load_impl(AlignedFileReader &reader, size_t offset)
Expand Down
2 changes: 1 addition & 1 deletion src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2065,7 +2065,7 @@ LabelT Index<T, TagT, LabelT>::get_converted_label(const std::string &raw_label)
return _universal_label;
}
std::stringstream stream;
stream << "Unable to find label in the Label Map";
stream << "Unable to find label" << raw_label << "in the label map ";
diskann::cerr << stream.str() << std::endl;
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
}
Expand Down
16 changes: 16 additions & 0 deletions src/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,22 @@ double calculate_range_search_recall(uint32_t num_queries, std::vector<std::vect
return total_recall / (num_queries);
}

/**
 * Splits `string_to_split` at every occurrence of `delimiter` and appends the
 * pieces to `pieces` (entries already in `pieces` are kept).
 *
 * Empty pieces caused by a leading delimiter or by adjacent delimiters are
 * appended, whereas the empty remainder after a trailing delimiter is not.
 * An empty input therefore appends nothing.
 *
 * An empty `delimiter` cannot split anything: the whole input is appended as
 * a single piece when it is non-empty.
 *
 * @param string_to_split Text to split.
 * @param delimiter       Separator; may be longer than one character.
 * @param pieces          Output vector the pieces are appended to.
 */
void split_string(const std::string &string_to_split, const std::string &delimiter, std::vector<std::string> &pieces)
{
    // std::string::find with an empty needle matches at `start` without ever
    // advancing, so the loop below would never terminate and would keep
    // appending empty strings.
    if (delimiter.empty())
    {
        if (!string_to_split.empty())
        {
            pieces.push_back(string_to_split);
        }
        return;
    }

    size_t start = 0;
    size_t end;
    while ((end = string_to_split.find(delimiter, start)) != std::string::npos)
    {
        pieces.push_back(string_to_split.substr(start, end - start));
        start = end + delimiter.length();
    }
    // Whatever follows the last delimiter is the final piece; skip it when
    // the input ended exactly on a delimiter (or was empty).
    if (start != string_to_split.length())
    {
        pieces.push_back(string_to_split.substr(start));
    }
}


#ifdef EXEC_ENV_OLS
void get_bin_metadata(AlignedFileReader &reader, size_t &npts, size_t &ndim, size_t offset)
{
Expand Down

0 comments on commit 5473656

Please sign in to comment.