Skip to content

Commit 61af797

Browse files
author
Jakub Tarnawski
committed Nov 10, 2023
exact_kde mode for partition_hmetis
1 parent 56c021d commit 61af797

File tree

3 files changed

+120
-11
lines changed

3 files changed

+120
-11
lines changed
 

‎include/math_utils.h

+5-1
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,17 @@ namespace math_utils {
1313
// compute l2-squared norms of data stored in row major num_points * dim,
1414
// needs
1515
// to be pre-allocated
16-
void compute_vecs_l2sq(float* vecs_l2sq, float* data, const size_t num_points,
16+
void compute_vecs_l2sq(float* vecs_l2sq, const float* data, const size_t num_points,
1717
const size_t dim);
1818

1919
void rotate_data_randomly(float* data, size_t num_points, size_t dim,
2020
float* rot_mat, float*& new_mat,
2121
bool transpose_rot = false);
2222

23+
DISKANN_DLLEXPORT std::unique_ptr<float[]> compute_all_distances(
24+
const float* const points, const size_t num_points, const size_t dim,
25+
const float* const centers, const size_t num_centers);
26+
2327
// calculate closest center to data of num_points * dim (row major)
2428
// centers is num_centers * dim (row major)
2529
// data_l2sq has pre-computed squared norms of data

‎src/math_utils.cpp

+47-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace math_utils {
2121
// compute l2-squared norms of data stored in row major num_points * dim,
2222
// needs
2323
// to be pre-allocated
24-
void compute_vecs_l2sq(float* vecs_l2sq, float* data, const size_t num_points,
24+
void compute_vecs_l2sq(float* vecs_l2sq, const float* data, const size_t num_points,
2525
const size_t dim) {
2626
#pragma omp parallel for schedule(static, 8192)
2727
for (int64_t n_iter = 0; n_iter < (_s64) num_points; n_iter++) {
@@ -199,6 +199,52 @@ namespace math_utils {
199199
delete[] pts_norms_squared;
200200
}
201201

202+
std::unique_ptr<float[]> compute_all_distances(
203+
const float* const points, const size_t num_points, const size_t dim,
204+
const float* const centers, const size_t num_centers) {
205+
206+
float* centers_l2sq = new float[num_centers];
207+
float* pts_l2sq = new float[num_points];
208+
209+
compute_vecs_l2sq(pts_l2sq, points, num_points, dim);
210+
compute_vecs_l2sq(centers_l2sq, centers, num_centers, dim);
211+
212+
float* ones_a = new float[num_centers];
213+
float* ones_b = new float[num_points];
214+
215+
for (size_t i = 0; i < num_centers; i++) {
216+
ones_a[i] = 1.0;
217+
}
218+
for (size_t i = 0; i < num_points; i++) {
219+
ones_b[i] = 1.0;
220+
}
221+
222+
std::unique_ptr<float[]> dist_matrix = std::make_unique<float[]>(
223+
num_points * num_centers);
224+
225+
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, (MKL_INT) num_points,
226+
(MKL_INT) num_centers, (MKL_INT) 1, 1.0f, pts_l2sq,
227+
(MKL_INT) 1, ones_a, (MKL_INT) 1, 0.0f, dist_matrix.get(),
228+
(MKL_INT) num_centers);
229+
230+
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, (MKL_INT) num_points,
231+
(MKL_INT) num_centers, (MKL_INT) 1, 1.0f, ones_b, (MKL_INT) 1,
232+
centers_l2sq, (MKL_INT) 1, 1.0f, dist_matrix.get(),
233+
(MKL_INT) num_centers);
234+
235+
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, (MKL_INT) num_points,
236+
(MKL_INT) num_centers, (MKL_INT) dim, -2.0f, points,
237+
(MKL_INT) dim, centers, (MKL_INT) dim, 1.0f, dist_matrix.get(),
238+
(MKL_INT) num_centers);
239+
240+
delete[] ones_a;
241+
delete[] ones_b;
242+
delete[] centers_l2sq;
243+
delete[] pts_l2sq;
244+
return dist_matrix;
245+
}
246+
247+
202248
// if to_subtract is 1, will subtract nearest center from each row. Else will
203249
// add. Output will be in data_load itself.
204250
// Nearest centers need to be provided in closest_centers.

‎tests/utils/partition_hmetis.cpp

+68-9
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,8 @@ int aux_main(const std::string &input_file,
250250
const std::string& query_file, const std::string& gt_file,
251251
const std::string& hmetis_file, const std::string& mode,
252252
const unsigned K, const unsigned query_fanout,
253-
const unsigned num_subcentroids) {
253+
const unsigned num_subcentroids,
254+
const float sigma) {
254255

255256
// load dataset
256257
// TODO for later: handle datasets that don't fit in memory
@@ -286,7 +287,7 @@ int aux_main(const std::string &input_file,
286287
points_routed_to_shard[shard_of_point[point_id]].push_back(point_id);
287288
}
288289

289-
290+
290291
// write shards to disk
291292
diskann::cout << "Writing shards to disk..." << std::endl;
292293
int ret = write_shards_to_disk<T>(output_file_prefix, false, points.get(),
@@ -640,6 +641,55 @@ int aux_main(const std::string &input_file,
640641
"approximation by MULTIcentroids"
641642
<< std::endl;
642643

644+
} else if (mode == "exact_kde") {
645+
646+
// compute distances from each query to each point
647+
// (we're doing worse than brute force, but it's just an experiment)
648+
std::unique_ptr<float[]> queries_float =
649+
std::make_unique<float[]>(num_queries * dim);
650+
diskann::convert_types<T, float>(queries.get(), queries_float.get(),
651+
num_queries, dim);
652+
// now do the same with the points
653+
std::unique_ptr<float[]> points_float =
654+
std::make_unique<float[]>(num_points * dim);
655+
diskann::convert_types<T, float>(points.get(), points_float.get(),
656+
num_points, dim);
657+
658+
constexpr size_t num_queries_per_batch = 100;
659+
for (size_t query_from = 0; query_from < num_queries;
660+
query_from += num_queries_per_batch) {
661+
const size_t query_to =
662+
std::min(query_from + num_queries_per_batch, num_queries);
663+
std::unique_ptr<float[]> distances_for_batch =
664+
math_utils::compute_all_distances(
665+
queries_float.get() + query_from * dim, query_to - query_from,
666+
dim, points_float.get(), num_points);
667+
for (size_t query_id = query_from; query_id < query_to; ++query_id) {
668+
query_to_shards.emplace_back();
669+
// compute exact KDE values for each shard
670+
std::vector<std::pair<float, size_t>> shards_with_scores;
671+
for (size_t shard_id = 0; shard_id < num_shards; ++shard_id) {
672+
// compute score (KDE) of shard_id for query_id
673+
float kde = 0.0;
674+
const float *distances_for_this_query =
675+
distances_for_batch.get() + (query_id - query_from) * num_points;
676+
for (const uint32_t point_id : points_routed_to_shard[shard_id]) {
677+
const float dist = distances_for_this_query[point_id];
678+
kde += exp(-dist * dist / (2 * sigma * sigma));
679+
}
680+
shards_with_scores.emplace_back(-kde, shard_id);
681+
}
682+
sort(shards_with_scores.begin(), shards_with_scores.end());
683+
for (int i = 0; i < num_shards; ++i) {
684+
const size_t shard_id = shards_with_scores[i].second;
685+
query_to_shards[query_id].emplace_back(
686+
shard_id, shard_to_count_of_GT_pts[query_id][shard_id]);
687+
// shard_to_count_of_GT_pts[query_id][shard_id] will be(come)
688+
// 0 if wasn't present
689+
}
690+
}
691+
}
692+
643693
} else {
644694
diskann::cout << "unsupported mode?" << std::endl;
645695
return -1;
@@ -690,7 +740,7 @@ int aux_main(const std::string &input_file,
690740
<< std::endl;
691741

692742
// 2. histogram of fanouts
693-
const size_t max_interesting_fanout =
743+
const size_t max_interesting_fanout = num_shards < 100 ? num_shards : K < 100 ? 100 :
694744
(mode == "from_ground_truth") ? K : 1.5 * K;
695745
std::vector<size_t> num_queries_with_fanout(max_interesting_fanout + 1,
696746
0);
@@ -771,6 +821,7 @@ int aux_main(const std::string &input_file,
771821
int main(int argc, char** argv) {
772822
std::string input_file, output_file_prefix, query_file, gt_file, hmetis_file, mode;
773823
unsigned K, query_fanout, num_subcentroids;
824+
float sigma; // for exact_kde
774825

775826
std::string data_type;
776827

@@ -796,7 +847,7 @@ int main(int argc, char** argv) {
796847
std::string("centroids")),
797848
"How to route queries to shards (from_ground_truth / "
798849
"centroids / multicentroids / multicentroids-random / "
799-
"multicentroids-neighbors / geomedian)");
850+
"multicentroids-neighbors / geomedian / exact_kde)");
800851
desc.add_options()("K,recall_at", po::value<unsigned>(&K)->default_value(0),
801852
"Number of points returned per query");
802853
desc.add_options()(
@@ -814,6 +865,9 @@ int main(int argc, char** argv) {
814865
desc.add_options()("num_subcentroids",
815866
po::value<unsigned>(&num_subcentroids)->default_value(0),
816867
"The number of subcentroids (for multicentroids modes)");
868+
desc.add_options()("sigma",
869+
po::value<float>(&sigma)->default_value(-1.0),
870+
"sigma for exact_kde");
817871

818872
po::variables_map vm;
819873
po::store(po::parse_command_line(argc, argv, desc), vm);
@@ -836,10 +890,10 @@ int main(int argc, char** argv) {
836890

837891
if (mode != "centroids" && mode != "multicentroids" && mode != "geomedian" &&
838892
mode != "from_ground_truth" && mode != "multicentroids-random" &&
839-
mode != "multicentroids-neighbors") {
893+
mode != "multicentroids-neighbors" && mode != "exact_kde") {
840894
diskann::cout
841895
<< "mode must be centroids, multicentroids, multicentroids-random, "
842-
"multicentroids-neighbors, geomedian or "
896+
"multicentroids-neighbors, geomedian, exact_kde, or "
843897
"from_ground_truth"
844898
<< std::endl;
845899
return -1;
@@ -865,19 +919,24 @@ int main(int argc, char** argv) {
865919
return -1;
866920
}
867921

922+
if (mode == "exact_kde" && sigma < 0) {
923+
diskann::cout << "if exact_kde mode, must specify sigma" << std::endl;
924+
return -1;
925+
}
926+
868927
try {
869928
if (data_type == std::string("float")) {
870929
return aux_main<float>(input_file, output_file_prefix, query_file,
871930
gt_file, hmetis_file, mode, K, query_fanout,
872-
num_subcentroids);
931+
num_subcentroids, sigma);
873932
} else if (data_type == std::string("int8")) {
874933
return aux_main<int8_t>(input_file, output_file_prefix, query_file,
875934
gt_file, hmetis_file, mode, K, query_fanout,
876-
num_subcentroids);
935+
num_subcentroids, sigma);
877936
} else if (data_type == std::string("uint8")) {
878937
return aux_main<uint8_t>(input_file, output_file_prefix, query_file,
879938
gt_file, hmetis_file, mode, K, query_fanout,
880-
num_subcentroids);
939+
num_subcentroids, sigma);
881940
} else {
882941
std::cerr << "Unsupported data type. Use float or int8 or uint8"
883942
<< std::endl;

0 commit comments

Comments
 (0)
Please sign in to comment.