@@ -250,7 +250,8 @@ int aux_main(const std::string &input_file,
250
250
const std::string& query_file, const std::string& gt_file,
251
251
const std::string& hmetis_file, const std::string& mode,
252
252
const unsigned K, const unsigned query_fanout,
253
- const unsigned num_subcentroids) {
253
+ const unsigned num_subcentroids,
254
+ const float sigma) {
254
255
255
256
// load dataset
256
257
// TODO for later: handle datasets that don't fit in memory
@@ -286,7 +287,7 @@ int aux_main(const std::string &input_file,
286
287
points_routed_to_shard[shard_of_point[point_id]].push_back (point_id);
287
288
}
288
289
289
-
290
+
290
291
// write shards to disk
291
292
diskann::cout << " Writing shards to disk..." << std::endl;
292
293
int ret = write_shards_to_disk<T>(output_file_prefix, false , points.get (),
@@ -640,6 +641,55 @@ int aux_main(const std::string &input_file,
640
641
" approximation by MULTIcentroids"
641
642
<< std::endl;
642
643
644
+ } else if (mode == " exact_kde" ) {
645
+
646
+ // compute distances from each query to each point
647
+ // (we're doing worse than brute force, but it's just an experiment)
648
+ std::unique_ptr<float []> queries_float =
649
+ std::make_unique<float []>(num_queries * dim);
650
+ diskann::convert_types<T, float >(queries.get (), queries_float.get (),
651
+ num_queries, dim);
652
+ // now do the same with the points
653
+ std::unique_ptr<float []> points_float =
654
+ std::make_unique<float []>(num_points * dim);
655
+ diskann::convert_types<T, float >(points.get (), points_float.get (),
656
+ num_points, dim);
657
+
658
+ constexpr size_t num_queries_per_batch = 100 ;
659
+ for (size_t query_from = 0 ; query_from < num_queries;
660
+ query_from += num_queries_per_batch) {
661
+ const size_t query_to =
662
+ std::min (query_from + num_queries_per_batch, num_queries);
663
+ std::unique_ptr<float []> distances_for_batch =
664
+ math_utils::compute_all_distances (
665
+ queries_float.get () + query_from * dim, query_to - query_from,
666
+ dim, points_float.get (), num_points);
667
+ for (size_t query_id = query_from; query_id < query_to; ++query_id) {
668
+ query_to_shards.emplace_back ();
669
+ // compute exact KDE values for each shard
670
+ std::vector<std::pair<float , size_t >> shards_with_scores;
671
+ for (size_t shard_id = 0 ; shard_id < num_shards; ++shard_id) {
672
+ // compute score (KDE) of shard_id for query_id
673
+ float kde = 0.0 ;
674
+ const float *distances_for_this_query =
675
+ distances_for_batch.get () + (query_id - query_from) * num_points;
676
+ for (const uint32_t point_id : points_routed_to_shard[shard_id]) {
677
+ const float dist = distances_for_this_query[point_id];
678
+ kde += exp (-dist * dist / (2 * sigma * sigma));
679
+ }
680
+ shards_with_scores.emplace_back (-kde, shard_id);
681
+ }
682
+ sort (shards_with_scores.begin (), shards_with_scores.end ());
683
+ for (int i = 0 ; i < num_shards; ++i) {
684
+ const size_t shard_id = shards_with_scores[i].second ;
685
+ query_to_shards[query_id].emplace_back (
686
+ shard_id, shard_to_count_of_GT_pts[query_id][shard_id]);
687
+ // shard_to_count_of_GT_pts[query_id][shard_id] will be(come)
688
+ // 0 if wasn't present
689
+ }
690
+ }
691
+ }
692
+
643
693
} else {
644
694
diskann::cout << " unsupported mode?" << std::endl;
645
695
return -1 ;
@@ -690,7 +740,7 @@ int aux_main(const std::string &input_file,
690
740
<< std::endl;
691
741
692
742
// 2. histogram of fanouts
693
- const size_t max_interesting_fanout =
743
+ const size_t max_interesting_fanout = num_shards < 100 ? num_shards : K < 100 ? 100 :
694
744
(mode == " from_ground_truth" ) ? K : 1.5 * K;
695
745
std::vector<size_t > num_queries_with_fanout (max_interesting_fanout + 1 ,
696
746
0 );
@@ -771,6 +821,7 @@ int aux_main(const std::string &input_file,
771
821
int main (int argc, char ** argv) {
772
822
std::string input_file, output_file_prefix, query_file, gt_file, hmetis_file, mode;
773
823
unsigned K, query_fanout, num_subcentroids;
824
+ float sigma; // for exact_kde
774
825
775
826
std::string data_type;
776
827
@@ -796,7 +847,7 @@ int main(int argc, char** argv) {
796
847
std::string (" centroids" )),
797
848
" How to route queries to shards (from_ground_truth / "
798
849
" centroids / multicentroids / multicentroids-random / "
799
- " multicentroids-neighbors / geomedian)" );
850
+ " multicentroids-neighbors / geomedian / exact_kde )" );
800
851
desc.add_options ()(" K,recall_at" , po::value<unsigned >(&K)->default_value (0 ),
801
852
" Number of points returned per query" );
802
853
desc.add_options ()(
@@ -814,6 +865,9 @@ int main(int argc, char** argv) {
814
865
desc.add_options ()(" num_subcentroids" ,
815
866
po::value<unsigned >(&num_subcentroids)->default_value (0 ),
816
867
" The number of subcentroids (for multicentroids modes)" );
868
+ desc.add_options ()(" sigma" ,
869
+ po::value<float >(&sigma)->default_value (-1.0 ),
870
+ " sigma for exact_kde" );
817
871
818
872
po::variables_map vm;
819
873
po::store (po::parse_command_line (argc, argv, desc), vm);
@@ -836,10 +890,10 @@ int main(int argc, char** argv) {
836
890
837
891
if (mode != " centroids" && mode != " multicentroids" && mode != " geomedian" &&
838
892
mode != " from_ground_truth" && mode != " multicentroids-random" &&
839
- mode != " multicentroids-neighbors" ) {
893
+ mode != " multicentroids-neighbors" && mode != " exact_kde " ) {
840
894
diskann::cout
841
895
<< " mode must be centroids, multicentroids, multicentroids-random, "
842
- " multicentroids-neighbors, geomedian or "
896
+ " multicentroids-neighbors, geomedian, exact_kde, or "
843
897
" from_ground_truth"
844
898
<< std::endl;
845
899
return -1 ;
@@ -865,19 +919,24 @@ int main(int argc, char** argv) {
865
919
return -1 ;
866
920
}
867
921
922
+ if (mode == " exact_kde" && sigma < 0 ) {
923
+ diskann::cout << " if exact_kde mode, must specify sigma" << std::endl;
924
+ return -1 ;
925
+ }
926
+
868
927
try {
869
928
if (data_type == std::string (" float" )) {
870
929
return aux_main<float >(input_file, output_file_prefix, query_file,
871
930
gt_file, hmetis_file, mode, K, query_fanout,
872
- num_subcentroids);
931
+ num_subcentroids, sigma );
873
932
} else if (data_type == std::string (" int8" )) {
874
933
return aux_main<int8_t >(input_file, output_file_prefix, query_file,
875
934
gt_file, hmetis_file, mode, K, query_fanout,
876
- num_subcentroids);
935
+ num_subcentroids, sigma );
877
936
} else if (data_type == std::string (" uint8" )) {
878
937
return aux_main<uint8_t >(input_file, output_file_prefix, query_file,
879
938
gt_file, hmetis_file, mode, K, query_fanout,
880
- num_subcentroids);
939
+ num_subcentroids, sigma );
881
940
} else {
882
941
std::cerr << " Unsupported data type. Use float or int8 or uint8"
883
942
<< std::endl;
0 commit comments