Skip to content

Commit

Permalink
added dynamic list sizing for range search in diskann
Browse files Browse the repository at this point in the history
  • Loading branch information
ravishankar authored and harsha-simhadri committed Sep 29, 2021
1 parent 10c1b3a commit ca243ff
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 53 deletions.
9 changes: 6 additions & 3 deletions include/pq_flash_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,12 @@ namespace diskann {


// Range search: collects the neighbours of query1 whose distance is at most
// `range`.
//   min_l_search   - candidate list size used for the first search pass
//   max_l_search   - the list size is doubled between passes until it would
//                    exceed this value
//   indices        - output; resized to the number of hits and filled with
//                    the ids of the matching points
//   distances      - output; resized likewise, holds the matching distances
//   min_beam_width - lower bound on the beam width used for each pass
//   stats          - optional per-query statistics, forwarded unchanged to
//                    the underlying beam search
// Returns the number of hits (equal to indices.size() on return).
DISKANN_DLLEXPORT _u32 range_search(const T *query1, const double range,
                                    const _u64          min_l_search,
                                    const _u64          max_l_search,
                                    std::vector<_u64> & indices,
                                    std::vector<float> &distances,
                                    const _u64          min_beam_width,
                                    QueryStats *        stats = nullptr);

std::shared_ptr<AlignedFileReader> &reader;
protected:
Expand Down
57 changes: 39 additions & 18 deletions src/pq_flash_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1214,25 +1214,46 @@ namespace diskann {
// Range search: returns every neighbour found within distance `range` of
// query1. The candidate list starts at min_l_search entries and is doubled
// for as long as the hits fill at least half of it and the doubled size does
// not exceed max_l_search. The first pass always runs, even when
// min_l_search is larger than max_l_search.
//
// indices and distances do not need to be pre-allocated: they are resized
// here and, on return, hold exactly the matching ids and their distances.
// The return value is the number of matching hits.
//
// NOTE(review): `stats` is handed to every pass; whether cached_beam_search
// accumulates into it or overwrites it is not visible from this function.
template<typename T>
_u32 PQFlashIndex<T>::range_search(const T *query1, const double range,
                                   const _u64          min_l_search,
                                   const _u64          max_l_search,
                                   std::vector<_u64> & indices,
                                   std::vector<float> &distances,
                                   const _u64          min_beam_width,
                                   QueryStats *        stats) {
  const float range_f = static_cast<float>(range);
  _u32        res_count = 0;

  // Kept as _u64 so large list sizes are neither truncated nor wrapped when
  // doubled. A zero starting size could never grow by doubling (the loop
  // would spin forever), so the smallest list searched is one entry.
  _u64 l_search = (min_l_search > 0) ? min_l_search : 1;

  while (true) {
    indices.resize(l_search);
    // Slots the search leaves untouched keep the sentinel distance, so they
    // compare as out of range below.
    distances.assign(l_search, std::numeric_limits<float>::max());

    // Beam width grows with the list (a fifth of it), bounded below by
    // min_beam_width and above by 100.
    _u64 cur_bw =
        (min_beam_width > l_search / 5) ? min_beam_width : l_search / 5;
    if (cur_bw > 100)
      cur_bw = 100;

    this->cached_beam_search(query1, l_search, l_search, indices.data(),
                             distances.data(), cur_bw, stats);

    // Count the leading run of results that are within range; this relies on
    // the search returning results ordered by increasing distance
    // (presumably guaranteed by cached_beam_search - confirm).
    _u64 hits = 0;
    while (hits < l_search && !(distances[hits] > range_f))
      hits++;
    res_count = static_cast<_u32>(hits);

    // Fewer than half the slots matched: a longer list is unlikely to
    // surface more hits, so stop here.
    if (hits < l_search / 2)
      break;

    l_search *= 2;
    if (l_search > max_l_search)
      break;
  }

  indices.resize(res_count);
  distances.resize(res_count);
  return res_count;
}

#ifdef EXEC_ENV_OLS
template<typename T>
Expand Down
69 changes: 37 additions & 32 deletions tests/range_search_disk_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,9 @@ void print_stats(std::string category, std::vector<float> percentiles,
template<typename T>
int search_disk_index(int argc, char** argv) {
// load query bin
T* query = nullptr;
// unsigned* gt_ids = nullptr;
// float* gt_dists = nullptr;
std::vector<std::vector<_u32>> groundtruth_ids;
T* query = nullptr;

std::vector<std::vector<_u32>> groundtruth_ids;
size_t query_num, query_dim, query_aligned_dim, gt_num;
std::vector<_u64> Lvec;

Expand Down Expand Up @@ -89,20 +88,18 @@ std::vector<std::vector<_u32>> groundtruth_ids;
_u32 beamwidth = std::atoi(argv[ctr++]);
std::string query_bin(argv[ctr++]);
std::string truthset_bin(argv[ctr++]);
double search_range = std::atof(argv[ctr++]);
double search_range = std::atof(argv[ctr++]);
std::string result_output_prefix(argv[ctr++]);

bool calc_recall_flag = false;

for (; ctr < (_u32) argc; ctr++) {
_u64 curL = std::atoi(argv[ctr]);
Lvec.push_back(curL);
Lvec.push_back(curL);
}

if (Lvec.size() == 0) {
diskann::cout
<< "No valid Lsearch found."
<< std::endl;
diskann::cout << "No valid Lsearch found." << std::endl;
return -1;
}

Expand All @@ -116,8 +113,11 @@ std::vector<std::vector<_u32>> groundtruth_ids;
query_aligned_dim);

if (file_exists(truthset_bin)) {
diskann::load_range_truthset(truthset_bin, groundtruth_ids, gt_num); // use for range search type of truthset
// diskann::prune_truthset_for_range(truthset_bin, search_range, groundtruth_ids, gt_num); // use for traditional truthset
diskann::load_range_truthset(
truthset_bin, groundtruth_ids,
gt_num); // use for range search type of truthset
// diskann::prune_truthset_for_range(truthset_bin, search_range,
// groundtruth_ids, gt_num); // use for traditional truthset
if (gt_num != query_num) {
diskann::cout
<< "Error. Mismatch in number of queries and ground truth data"
Expand Down Expand Up @@ -216,20 +216,14 @@ std::vector<std::vector<_u32>> groundtruth_ids;
<< std::endl;

std::vector<std::vector<std::vector<uint32_t>>> query_result_ids(Lvec.size());
std::vector<_u64> indices;
std::vector<float> distances;

uint32_t optimized_beamwidth = 2;
uint32_t max_list_size = 10000;

for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++) {
_u64 L = Lvec[test_id];
indices.clear();
distances.clear();
indices.resize(L*query_num);
distances.resize(L*query_num);

if (beamwidth <= 0) {
// diskann::cout<<"Tuning beamwidth.." << std::endl;
optimized_beamwidth =
optimize_beamwidth(_pFlashIndex, warmup, warmup_num,
warmup_aligned_dim, L, optimized_beamwidth);
Expand All @@ -240,20 +234,19 @@ std::vector<std::vector<_u32>> groundtruth_ids;
query_result_ids[test_id].resize(query_num);

diskann::QueryStats* stats = new diskann::QueryStats[query_num];
auto s = std::chrono::high_resolution_clock::now();

auto s = std::chrono::high_resolution_clock::now();
#pragma omp parallel for schedule(dynamic, 1)
for (_s64 i = 0; i < (int64_t) query_num; i++) {
_u32 res_count =
_pFlashIndex->range_search(
query + (i * query_aligned_dim), search_range, L,
indices.data() + i*L, distances.data() + i *L,
optimized_beamwidth, stats + i);
// std::cout<<res_count <<" ";
query_result_ids[test_id][i].reserve(res_count);
query_result_ids[test_id][i].resize(res_count);
for(_u32 idx = 0; idx< res_count; idx++)
query_result_ids[test_id][i][idx] = indices[i*L + idx];
std::vector<_u64> indices;
std::vector<float> distances;
_u32 res_count = _pFlashIndex->range_search(
query + (i * query_aligned_dim), search_range, L, max_list_size, indices,
distances, optimized_beamwidth, stats + i);
query_result_ids[test_id][i].reserve(res_count);
query_result_ids[test_id][i].resize(res_count);
for (_u32 idx = 0; idx < res_count; idx++)
query_result_ids[test_id][i][idx] = indices[idx];
}
auto e = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = e - s;
Expand All @@ -276,16 +269,28 @@ std::vector<std::vector<_u32>> groundtruth_ids;
[](const diskann::QueryStats& stats) { return stats.cpu_us; });

float recall = 0;
float ratio_of_sums = 0;
if (calc_recall_flag) {
recall = diskann::calculate_range_search_recall(query_num, groundtruth_ids, query_result_ids[test_id]);
recall = diskann::calculate_range_search_recall(
query_num, groundtruth_ids, query_result_ids[test_id]);

_u32 total_true_positive = 0;
_u32 total_positive = 0;
for (_u32 i = 0; i < query_num; i++) {
total_true_positive += query_result_ids[test_id][i].size();
total_positive += groundtruth_ids[i].size();
}

ratio_of_sums = (1.0 * total_true_positive) / (1.0 * total_positive);
}

diskann::cout << std::setw(6) << L << std::setw(12) << optimized_beamwidth
<< std::setw(16) << qps << std::setw(16) << mean_latency
<< std::setw(16) << latency_999 << std::setw(16) << mean_ios
<< std::setw(16) << mean_cpuus;
if (calc_recall_flag) {
diskann::cout << std::setw(16) << recall << std::endl;
diskann::cout << std::setw(16) << recall << "," << ratio_of_sums
<< std::endl;
} else
diskann::cout << std::endl;
}
Expand Down

0 comments on commit ca243ff

Please sign in to comment.