From 4f58b387b6befd569950d6a4713925c1327dede3 Mon Sep 17 00:00:00 2001
From: Paris Morgan
Date: Wed, 16 Oct 2024 08:06:37 -0700
Subject: [PATCH] Fix types in qv_partition_with_scores() and train_no_init()
 (#546)

---
 src/include/detail/flat/qv.h    |   2 +-
 src/include/index/kmeans.h      |  17 ++---
 src/include/test/unit_kmeans.cc | 113 ++++++++++++++++++++++++++++++++
 3 files changed, 121 insertions(+), 11 deletions(-)

diff --git a/src/include/detail/flat/qv.h b/src/include/detail/flat/qv.h
index e0f377d29..e178f8783 100644
--- a/src/include/detail/flat/qv.h
+++ b/src/include/detail/flat/qv.h
@@ -536,7 +536,7 @@ auto qv_partition_with_scores(
 
   // Just need a single vector
   std::vector top_k(q.num_cols());
-  std::vector top_k_scores(q.num_cols());
+  std::vector top_k_scores(q.num_cols());
 
   auto par = stdx::execution::indexed_parallel_policy{(size_t)nthreads};
   stdx::range_for_each(
diff --git a/src/include/index/kmeans.h b/src/include/index/kmeans.h
index 2bcc464a2..4ac9edf38 100644
--- a/src/include/index/kmeans.h
+++ b/src/include/index/kmeans.h
@@ -254,7 +254,6 @@ void train_no_init(
   if (::num_vectors(training_set) == 0) {
     return;
   }
-  using feature_type = typename V::value_type;
   using centroid_feature_type = typename C::value_type;
   using index_type = size_t;
 
@@ -276,10 +275,9 @@
   // How many centroids should we try to fix up
   size_t heap_size =
       std::ceil(reassign_ratio_ * static_cast(num_partitions_)) + 5;
-  auto high_scores = fixed_min_pair_heap<
-      feature_type,
-      index_type,
-      std::greater>(heap_size, std::greater());
+  auto high_scores =
+      fixed_min_pair_heap>(
+          heap_size, std::greater());
   auto low_degrees = fixed_min_pair_heap(heap_size);
 
   // @todo parallelize -- by partition
@@ -326,7 +324,7 @@
     std::sort_heap(begin(high_scores), end(high_scores), [](auto a, auto b) {
       return std::get<0>(a) > std::get<0>(b);
     });
-    for (size_t i = 0; i < size(low_degrees) &&
+    for (size_t i = 0; i < std::min(size(low_degrees), size(high_scores)) &&
                            std::get<0>(low_degrees[i]) <= lower_degree_bound;
          ++i) {
       // std::cout << "i: " << i << " low_degrees: ("
@@ -527,10 +525,9 @@ auto sub_kmeans(
 #ifdef REASSIGN
   // How many centroids should we try to fix up
   size_t heap_size = std::ceil(reassign_ratio * num_clusters) + 5;
-  auto high_scores = fixed_min_pair_heap<
-      feature_type,
-      index_type,
-      std::greater>(heap_size, std::greater());
+  auto high_scores =
+      fixed_min_pair_heap>(
+          heap_size, std::greater());
   auto low_degrees = fixed_min_pair_heap(heap_size);
 #endif
 
diff --git a/src/include/test/unit_kmeans.cc b/src/include/test/unit_kmeans.cc
index e1c499e21..889c9258e 100644
--- a/src/include/test/unit_kmeans.cc
+++ b/src/include/test/unit_kmeans.cc
@@ -288,3 +288,116 @@ TEST_CASE(
   // Verify results for kmeans_pp
   verify_centroids(centroids_pp);
 }
+
+TEST_CASE("test kmeans train_no_init random data", "[kmeans]") {
+  // Sample data: 6-dimensional vectors, 10 vectors total (column major)
+  std::vector data = {
+      7, 6, 249, 3, 2, 2, 254, 249, 7, 0, 9, 3, 248, 255, 4,
+      0, 249, 0, 251, 249, 245, 3, 250, 252, 6, 7, 5, 252, 4, 5,
+      9, 9, 248, 254, 7, 1, 4, 1, 253, 5, 2, 255, 250, 6, 3,
+      0, 2, 249, 0, 250, 5, 4, 5, 2, 99, 30, 3, 1, 55, 88};
+
+  ColMajorMatrix training_set(6, 10);  // 6 rows, 10 columns
+  std::copy(data.begin(), data.end(), training_set.data());
+
+  // Initial centroids: 6-dimensional vectors, 3 centroids total
+  std::vector centroids_data = {
+      3,
+      5,
+      250,
+      245,
+      249,
+      0,
+      249,
+      248,
+      250,
+      0,
+      5,
+      251,
+      251,
+      249,
+      245,
+      3,
+      250,
+      252};
+
+  ColMajorMatrix centroids(6, 3);
+  std::copy(centroids_data.begin(), centroids_data.end(), centroids.data());
+
+  size_t dimension_ = 6;
+  size_t num_partitions_ = 3;
+  uint32_t max_iterations = 2;
+  float tol_ = 2.5e-05;
+  size_t num_threads_ = 12;
+  float reassign_ratio_ = 0.075;
+
+  CHECK(centroids.num_rows() == dimension_);
+  CHECK(centroids.num_cols() == num_partitions_);
+
+  train_no_init(
+      training_set,
+      centroids,
+      dimension_,
+      num_partitions_,
+      max_iterations,
+      tol_,
+      num_threads_,
+      reassign_ratio_);
+
+  CHECK(centroids.num_rows() == dimension_);
+  CHECK(centroids.num_cols() == num_partitions_);
+
+  {
+    ColMajorMatrix original_centroids(6, 3);
+    std::copy(
+        centroids_data.begin(),
+        centroids_data.end(),
+        original_centroids.data());
+    float max_diff = 0.0;
+    for (size_t i = 0; i < centroids.num_cols(); ++i) {
+      float diff =
+          sum_of_squares_distance{}(centroids[i], original_centroids[i]);
+      max_diff = std::max(max_diff, diff);
+    }
+    REQUIRE_THAT(max_diff, Catch::Matchers::WithinAbs(91858.75f, 1e-2));
+  }
+}
+
+TEST_CASE("test kmeans train_no_init training_set is empty", "[kmeans]") {
+  ColMajorMatrix training_set(0, 0);  // Empty training set
+  ColMajorMatrix centroids(0, 0);     // Empty centroids
+
+  train_no_init(training_set, centroids, 0, 0, 2, 0.00001, 12, 0.075);
+
+  CHECK(centroids.num_cols() == 0);  // Expect centroids to remain empty
+  CHECK(centroids.num_rows() == 0);
+}
+
+TEST_CASE(
+    "test kmeans train_no_init number of centroids exceeds data points",
+    "[kmeans]") {
+  std::vector small_data = {1, 2, 3, 4, 5, 6};
+  ColMajorMatrix small_training_set(6, 1);  // 6 rows, 1 column
+  std::copy(small_data.begin(), small_data.end(), small_training_set.data());
+
+  ColMajorMatrix more_centroids(
+      6, 3);  // More centroids than data points
+
+  train_no_init(
+      small_training_set, more_centroids, 6, 3, 2, 0.00001, 12, 0.075);
+
+  CHECK(more_centroids.num_cols() == 3);  // Verify centroids were generated
+  for (size_t i = 0; i < more_centroids.num_cols(); ++i) {
+    // Ensure some centroids match the data point and the rest are zeros
+    bool is_zero = std::all_of(
+        more_centroids[i].begin(), more_centroids[i].end(), [](float val) {
+          return val == 0.0f;
+        });
+    if (!is_zero) {
+      CHECK(std::equal(
+          more_centroids[i].begin(),
+          more_centroids[i].end(),
+          small_data.begin()));
+    }
+  }
+}
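
Several `-`/`+` pairs in the diff read identically here because the angle-bracketed template arguments (for example in `std::vector top_k_scores(q.num_cols());`, `fixed_min_pair_heap>(...)`, and the `ColMajorMatrix` declarations) are not visible in this rendering. The sketch below shows one plausible reading of the type fix named in the subject line, keeping ids in an integral index type and distances in a floating-point score type; every concrete name in it (index_type, score_type, float, size_t) is an assumption made for illustration, not an identifier recovered from qv.h or kmeans.h.

    // Hypothetical sketch only: ids in an integral index type, scores in a
    // floating-point type. No names below are taken from the repository.
    #include <cstddef>
    #include <utility>
    #include <vector>

    int main() {
      using index_type = size_t;  // partition / vector ids
      using score_type = float;   // sum-of-squares distances

      size_t num_queries = 10;

      // Mirrors the top_k / top_k_scores pair in qv_partition_with_scores():
      // one nearest-centroid id and one distance per query vector.
      std::vector<index_type> top_k(num_queries);
      std::vector<score_type> top_k_scores(num_queries);

      // Candidate (score, id) pairs for reassignment, standing in for the
      // repository's fixed_min_pair_heap keyed on the score type rather than
      // on the training set's feature type.
      std::vector<std::pair<score_type, index_type>> high_scores;
      high_scores.reserve(top_k.size());

      (void)top_k_scores;
      return 0;
    }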
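
The other behavioral change above, in train_no_init(), is the loop bound: the reassignment loop now stops at the smaller of the two heaps instead of running over all of low_degrees. Below is a minimal, self-contained sketch of that pattern using std::vectors of pairs in place of the repository's fixed_min_pair_heap; the sample values, and the assumption that the loop body also reads high_scores[i] at the same index, are illustrative only.

    #include <algorithm>
    #include <cstddef>
    #include <cstdio>
    #include <utility>
    #include <vector>

    int main() {
      // (degree, partition id): partitions that ended up with too few vectors.
      std::vector<std::pair<unsigned, size_t>> low_degrees = {
          {0, 4}, {1, 7}, {1, 9}};
      // (score, vector id): the most expensive assignments, sorted high to low.
      std::vector<std::pair<float, size_t>> high_scores = {
          {9125.0f, 3}, {8700.0f, 8}};

      unsigned lower_degree_bound = 1;

      // Bounding the loop by the smaller heap keeps high_scores[i] in range
      // when low_degrees holds more entries than high_scores.
      for (size_t i = 0;
           i < std::min(low_degrees.size(), high_scores.size()) &&
           std::get<0>(low_degrees[i]) <= lower_degree_bound;
           ++i) {
        std::printf(
            "move vector %zu (score %.1f) into partition %zu\n",
            std::get<1>(high_scores[i]),
            static_cast<double>(std::get<0>(high_scores[i])),
            std::get<1>(low_degrees[i]));
      }
      return 0;
    }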