diff --git a/include/pq_flash_index.h b/include/pq_flash_index.h index d119b8197..03d8754cb 100644 --- a/include/pq_flash_index.h +++ b/include/pq_flash_index.h @@ -117,9 +117,12 @@ namespace diskann { DISKANN_DLLEXPORT _u32 range_search(const T *query1, const double range, - const _u64 l_search, _u64* indices, float* distances, - const _u64 beam_width, - QueryStats *stats = nullptr); + const _u64 min_l_search, + const _u64 max_l_search, + std::vector<_u64> & indices, + std::vector &distances, + const _u64 min_beam_width, + QueryStats * stats); std::shared_ptr &reader; protected: diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index 9c1bea8ab..c9f0664ec 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -1214,25 +1214,46 @@ namespace diskann { // range search returns results of all neighbors within distance of range. indices and distances need to be pre-allocated of size l_search // and the return value is the number of matching hits. - template +template _u32 PQFlashIndex::range_search(const T *query1, const double range, - const _u64 l_search, _u64* indices, float* distances, - const _u64 beam_width, - QueryStats *stats) { -_u32 res_count = 0; -this->cached_beam_search(query1, l_search, l_search, indices, distances, beam_width, stats); -for (_u32 i = 0; i < l_search; i++) { - //std::cout< (float) range) { - res_count = i; - break; - } else if (i == l_search -1) - res_count = l_search; -} -//std::cout<<"\n\n"< & indices, + std::vector &distances, + const _u64 min_beam_width, + QueryStats * stats) { + _u32 res_count = 0; + + bool stop_flag = false; + + _u32 l_search = min_l_search; // starting size of the candidate list + while (!stop_flag) { + indices.resize(l_search); + distances.resize(l_search); + _u64 cur_bw = + min_beam_width > (l_search / 5) ? min_beam_width : l_search / 5; + cur_bw = (cur_bw > 100) ? 100 : cur_bw; + for (auto &x : distances) + x = std::numeric_limits::max(); + this->cached_beam_search(query1, l_search, l_search, indices.data(), + distances.data(), cur_bw, stats); + for (_u32 i = 0; i < l_search; i++) { + if (distances[i] > (float) range) { + res_count = i; + break; + } else if (i == l_search - 1) + res_count = l_search; + } + if (res_count < (_u32)(l_search / 2.0)) + stop_flag = true; + l_search = l_search * 2; + if (l_search > max_l_search) + stop_flag = true; + } + indices.resize(res_count); + distances.resize(res_count); + return res_count; + } #ifdef EXEC_ENV_OLS template diff --git a/tests/range_search_disk_index.cpp b/tests/range_search_disk_index.cpp index a181772ba..67a02713d 100644 --- a/tests/range_search_disk_index.cpp +++ b/tests/range_search_disk_index.cpp @@ -50,10 +50,9 @@ void print_stats(std::string category, std::vector percentiles, template int search_disk_index(int argc, char** argv) { // load query bin - T* query = nullptr; -// unsigned* gt_ids = nullptr; -// float* gt_dists = nullptr; -std::vector> groundtruth_ids; + T* query = nullptr; + + std::vector> groundtruth_ids; size_t query_num, query_dim, query_aligned_dim, gt_num; std::vector<_u64> Lvec; @@ -89,20 +88,18 @@ std::vector> groundtruth_ids; _u32 beamwidth = std::atoi(argv[ctr++]); std::string query_bin(argv[ctr++]); std::string truthset_bin(argv[ctr++]); - double search_range = std::atof(argv[ctr++]); + double search_range = std::atof(argv[ctr++]); std::string result_output_prefix(argv[ctr++]); bool calc_recall_flag = false; for (; ctr < (_u32) argc; ctr++) { _u64 curL = std::atoi(argv[ctr]); - Lvec.push_back(curL); + Lvec.push_back(curL); } if (Lvec.size() == 0) { - diskann::cout - << "No valid Lsearch found." - << std::endl; + diskann::cout << "No valid Lsearch found." << std::endl; return -1; } @@ -116,8 +113,11 @@ std::vector> groundtruth_ids; query_aligned_dim); if (file_exists(truthset_bin)) { - diskann::load_range_truthset(truthset_bin, groundtruth_ids, gt_num); // use for range search type of truthset -// diskann::prune_truthset_for_range(truthset_bin, search_range, groundtruth_ids, gt_num); // use for traditional truthset + diskann::load_range_truthset( + truthset_bin, groundtruth_ids, + gt_num); // use for range search type of truthset + // diskann::prune_truthset_for_range(truthset_bin, search_range, + // groundtruth_ids, gt_num); // use for traditional truthset if (gt_num != query_num) { diskann::cout << "Error. Mismatch in number of queries and ground truth data" @@ -216,20 +216,14 @@ std::vector> groundtruth_ids; << std::endl; std::vector>> query_result_ids(Lvec.size()); - std::vector<_u64> indices; - std::vector distances; uint32_t optimized_beamwidth = 2; + uint32_t max_list_size = 10000; for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++) { _u64 L = Lvec[test_id]; - indices.clear(); - distances.clear(); - indices.resize(L*query_num); - distances.resize(L*query_num); if (beamwidth <= 0) { - // diskann::cout<<"Tuning beamwidth.." << std::endl; optimized_beamwidth = optimize_beamwidth(_pFlashIndex, warmup, warmup_num, warmup_aligned_dim, L, optimized_beamwidth); @@ -240,20 +234,19 @@ std::vector> groundtruth_ids; query_result_ids[test_id].resize(query_num); diskann::QueryStats* stats = new diskann::QueryStats[query_num]; - - auto s = std::chrono::high_resolution_clock::now(); + + auto s = std::chrono::high_resolution_clock::now(); #pragma omp parallel for schedule(dynamic, 1) for (_s64 i = 0; i < (int64_t) query_num; i++) { - _u32 res_count = - _pFlashIndex->range_search( - query + (i * query_aligned_dim), search_range, L, - indices.data() + i*L, distances.data() + i *L, - optimized_beamwidth, stats + i); - // std::cout< indices; + std::vector distances; + _u32 res_count = _pFlashIndex->range_search( + query + (i * query_aligned_dim), search_range, L, max_list_size, indices, + distances, optimized_beamwidth, stats + i); + query_result_ids[test_id][i].reserve(res_count); + query_result_ids[test_id][i].resize(res_count); + for (_u32 idx = 0; idx < res_count; idx++) + query_result_ids[test_id][i][idx] = indices[idx]; } auto e = std::chrono::high_resolution_clock::now(); std::chrono::duration diff = e - s; @@ -276,8 +269,19 @@ std::vector> groundtruth_ids; [](const diskann::QueryStats& stats) { return stats.cpu_us; }); float recall = 0; + float ratio_of_sums = 0; if (calc_recall_flag) { - recall = diskann::calculate_range_search_recall(query_num, groundtruth_ids, query_result_ids[test_id]); + recall = diskann::calculate_range_search_recall( + query_num, groundtruth_ids, query_result_ids[test_id]); + + _u32 total_true_positive = 0; + _u32 total_positive = 0; + for (_u32 i = 0; i < query_num; i++) { + total_true_positive += query_result_ids[test_id][i].size(); + total_positive += groundtruth_ids[i].size(); + } + + ratio_of_sums = (1.0 * total_true_positive) / (1.0 * total_positive); } diskann::cout << std::setw(6) << L << std::setw(12) << optimized_beamwidth @@ -285,7 +289,8 @@ std::vector> groundtruth_ids; << std::setw(16) << latency_999 << std::setw(16) << mean_ios << std::setw(16) << mean_cpuus; if (calc_recall_flag) { - diskann::cout << std::setw(16) << recall << std::endl; + diskann::cout << std::setw(16) << recall << "," << ratio_of_sums + << std::endl; } else diskann::cout << std::endl; }