Skip to content

Commit

Permalink
add multicentroids-linmax
Browse files Browse the repository at this point in the history
  • Loading branch information
Jakub Tarnawski committed Nov 13, 2023
1 parent eac9d88 commit 1985d69
Showing 1 changed file with 81 additions and 12 deletions.
93 changes: 81 additions & 12 deletions tests/utils/partition_hmetis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,21 +51,22 @@ void compute_centroid(const size_t dim, const T* points,
template<typename T>
void pick_random_points(const size_t dim, const T* points,
const std::vector<uint32_t>& point_ids,
float* subcentroids, const float* centroid, const unsigned num_subcentroids) {
float* subcentroids, const float* centroid,
const unsigned num_subcentroids) {
if (point_ids.empty()) {
for (int i = 0; i < num_subcentroids; ++i) {
assign_junk(dim, subcentroids + i * dim);
}
} else {
// the first subcentroid is the centroid
for (int j = 0; j < dim; ++j) {
subcentroids[j] = centroid[j];
}
subcentroids[j] = centroid[j];
}

// pick the other subcentroids as random points
std::random_device rd;
auto x = rd();
std::mt19937 generator(x);
std::random_device rd;
auto x = rd();
std::mt19937 generator(x);
std::uniform_int_distribution<uint32_t> int_dist(0, point_ids.size() - 1);
for (int i = 1; i < num_subcentroids; ++i) {
uint32_t random_point_id = point_ids[int_dist(generator)];
Expand All @@ -76,6 +77,58 @@ void pick_random_points(const size_t dim, const T* points,
}
}


// Fills `subcentroids` (num_subcentroids rows of `dim` floats each) for one
// shard. Row 0 is a copy of `centroid`; every further row is the not-yet-chosen
// point of the shard that maximizes the dot product with a freshly sampled
// direction vector (a maximizer of a linear function lies on the convex hull
// of the shard's points).
//
//   dim               number of coordinates per point
//   points            coordinates of all points; point p occupies
//                     points[p * dim] .. points[p * dim + dim - 1]
//   point_ids         ids of the points belonging to this shard
//   subcentroids      output buffer holding num_subcentroids * dim floats
//   centroid          the shard's centroid (dim floats), copied into row 0
//   num_subcentroids  number of rows to produce; 0 means "write nothing"
//
// If `point_ids` is empty, every row is filled by assign_junk(). If the shard
// runs out of unused points before all rows are filled, the remaining rows are
// filled by assign_junk() as well.
template<typename T>
void pick_linmax_points(const size_t dim, const T* points,
                        const std::vector<uint32_t>& point_ids,
                        float* subcentroids, const float* centroid,
                        const unsigned num_subcentroids) {
  // Nothing was requested: in particular, do not copy the centroid into an
  // output slice that has room for zero rows.
  if (num_subcentroids == 0) {
    return;
  }

  if (point_ids.empty()) {
    for (unsigned i = 0; i < num_subcentroids; ++i) {
      assign_junk(dim, subcentroids + i * dim);
    }
    return;
  }

  // the first subcentroid is the centroid
  for (size_t j = 0; j < dim; ++j) {
    subcentroids[j] = centroid[j];
  }

  // point_used[j] refers to point_ids[j]; a point is picked at most once
  std::vector<bool> point_used(point_ids.size(), false);

  // pick the other subcentroids as a selection of points on the convex hull
  std::vector<float> random_direction(dim);
  for (unsigned i = 1; i < num_subcentroids; ++i) {
    for (size_t j = 0; j < dim; ++j) {
      // NOTE(review): sample_random_number is defined elsewhere; the meaning
      // of its `false` argument is not visible from this function.
      random_direction[j] = sample_random_number(false);
      // here, instead of picking a random direction, we could maybe
      // pick a random query if we had access to a representative set of
      // queries?
    }

    // Scan the unused points for the largest dot product with the direction.
    // Only dot products strictly above the -1e30 floor are eligible.
    float max_value = -1e30f;
    bool found = false;
    size_t max_index = 0;
    for (size_t j = 0; j < point_ids.size(); ++j) {
      if (point_used[j]) {
        continue;
      }
      const T* candidate = points + static_cast<size_t>(point_ids[j]) * dim;
      float dot_product = 0.0f;
      for (size_t k = 0; k < dim; ++k) {
        dot_product += candidate[k] * random_direction[k];
      }
      if (dot_product > max_value) {
        max_value = dot_product;
        max_index = j;
        found = true;
      }
    }

    float* destination = subcentroids + i * dim;
    if (!found) {
      // no eligible point left for this row
      assign_junk(dim, destination);
    } else {
      const T* chosen =
          points + static_cast<size_t>(point_ids[max_index]) * dim;
      for (size_t j = 0; j < dim; ++j) {
        destination[j] = chosen[j];
      }
      point_used[max_index] = true;
    }
  }
}

template<typename T>
void compute_subcentroids(const size_t dim, const T* points,
const std::vector<uint32_t>& point_ids,
Expand Down Expand Up @@ -430,6 +483,18 @@ int aux_main(const std::string &input_file,
}
// (subcluster_counts does not get filled in this case)
}
} else if (mode == "multicentroids-linmax") {
subcentroids =
std::make_unique<float[]>(num_shards * num_subcentroids * dim);
for (size_t shard_id = 0; shard_id < num_shards; ++shard_id) {
// compute subcentroids by maximizing random linear functions,
// and also pick the cluster center
pick_linmax_points<T>(
dim, points.get(), points_routed_to_shard[shard_id],
subcentroids.get() + shard_id * num_subcentroids * dim,
centroids.get() + shard_id * dim, num_subcentroids);
// (subcluster_counts does not get filled in this case)
}
}
// subcentroids do not get saved to a file

Expand Down Expand Up @@ -530,7 +595,7 @@ int aux_main(const std::string &input_file,
"approximation by centroids"
<< std::endl;
} else if (mode == "multicentroids" || mode == "multicentroids-random" ||
mode == "multicentroids-neighbors") {
mode == "multicentroids-neighbors" || mode == "multicentroids-linmax") {

constexpr int submode = 1;

Expand Down Expand Up @@ -563,14 +628,17 @@ int aux_main(const std::string &input_file,
}
} else if (submode == 2) {
if (mode == "multicentroids-neighbors" ||
mode == "multicentroids-random") {
mode == "multicentroids-random" ||
mode == "multicentroids-linmax") {
diskann::cout << "Error: submode 2 only works with multicentroids "
"as it needs subcluster_counts[] to be filled out"
<< std::endl;
return -1;
}
// 2: order shards by sum_subcentroid 1/distance
// (actually, better: sum (# pts in subcluster) / distance)

// (TODO: maybe try here sth like KDE: sum_{subcentroid s} exp(-dist(s,query)^2/0.1^2) )
for (size_t query_id = 0; query_id < num_queries; ++query_id) {
std::vector<std::pair<float, size_t>> shards_with_scores;
for (size_t shard_id = 0; shard_id < num_shards; ++shard_id) {
Expand Down Expand Up @@ -877,7 +945,8 @@ int main(int argc, char** argv) {
std::string("centroids")),
"How to route queries to shards (from_ground_truth / "
"centroids / multicentroids / multicentroids-random / "
"multicentroids-neighbors / geomedian / kde)");
"multicentroids-neighbors / multicentroids-linmax / "
"geomedian / kde)");
desc.add_options()("K,recall_at", po::value<unsigned>(&K)->default_value(0),
"Number of points returned per query");
desc.add_options()(
Expand Down Expand Up @@ -923,10 +992,10 @@ int main(int argc, char** argv) {

if (mode != "centroids" && mode != "multicentroids" && mode != "geomedian" &&
mode != "from_ground_truth" && mode != "multicentroids-random" &&
mode != "multicentroids-neighbors" && mode != "kde") {
mode != "multicentroids-neighbors" && mode != "multicentroids-linmax" && mode != "kde") {
diskann::cout
<< "mode must be centroids, multicentroids, multicentroids-random, "
"multicentroids-neighbors, geomedian, kde, or "
"multicentroids-neighbors, multicentroids-linmax, geomedian, kde, or "
"from_ground_truth"
<< std::endl;
return -1;
Expand All @@ -945,7 +1014,7 @@ int main(int argc, char** argv) {
}

if ((mode == "multicentroids" || mode == "multicentroids-random" ||
mode == "multicentroids-neighbors") &&
mode == "multicentroids-neighbors" || mode == "multicentroids-linmax") &&
num_subcentroids == 0) {
diskann::cout << "if multicentroids mode, must specify num_subcentroids"
<< std::endl;
Expand Down

0 comments on commit 1985d69

Please sign in to comment.