From d6a3ba429700e785b1bc20f65fc25edddc5849fb Mon Sep 17 00:00:00 2001 From: yangzq50 <58433399+yangzq50@users.noreply.github.com> Date: Wed, 23 Oct 2024 21:08:10 +0800 Subject: [PATCH] Support scalar quantization for IVF index (#2090) ### What problem does this PR solve? Support 4-bit and 8-bit scalar quantization for IVF index Issue link:#2085 ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Refactoring --- src/storage/knn_index/emvb/emvb_index.cpp | 2 +- src/storage/knn_index/emvb/emvb_index.cppm | 2 +- .../knn_index/knn_ivf/ivf_index_storage.cpp | 4 +- .../knn_index/knn_ivf/ivf_index_storage.cppm | 2 +- .../knn_ivf/ivf_index_storage_parts.cpp | 527 +++++++++++++++--- 5 files changed, 458 insertions(+), 79 deletions(-) diff --git a/src/storage/knn_index/emvb/emvb_index.cpp b/src/storage/knn_index/emvb/emvb_index.cpp index 5e9adb0021..82db419322 100644 --- a/src/storage/knn_index/emvb/emvb_index.cpp +++ b/src/storage/knn_index/emvb/emvb_index.cpp @@ -438,7 +438,7 @@ EMVBQueryResultType EMVBIndex::query_token_num_helper(const f32 *query_ptr, u32 return query_token_num_helper(query_ptr, query_embedding_num, std::forward(query_args)...); } -template <> +template EMVBQueryResultType EMVBIndex::query_token_num_helper(const f32 *query_ptr, u32 query_embedding_num, auto &&...query_args) const { auto error_msg = fmt::format("EMVBIndex::GetQueryResult: query_embedding_num max value: {}, got {} instead.", current_max_query_token_num, diff --git a/src/storage/knn_index/emvb/emvb_index.cppm b/src/storage/knn_index/emvb/emvb_index.cppm index be2bfd5592..e7779b66b2 100644 --- a/src/storage/knn_index/emvb/emvb_index.cppm +++ b/src/storage/knn_index/emvb/emvb_index.cppm @@ -110,7 +110,7 @@ private: template EMVBQueryResultType query_token_num_helper(const f32 *query_ptr, u32 query_embedding_num, auto &&...query_args) const; - template <> + template EMVBQueryResultType query_token_num_helper(const f32 *query_ptr, u32 query_embedding_num, auto &&...query_args) const; template diff --git a/src/storage/knn_index/knn_ivf/ivf_index_storage.cpp b/src/storage/knn_index/knn_ivf/ivf_index_storage.cpp index a92eeef079..a91544dd9c 100644 --- a/src/storage/knn_index/knn_ivf/ivf_index_storage.cpp +++ b/src/storage/knn_index/knn_ivf/ivf_index_storage.cpp @@ -270,9 +270,7 @@ void IVF_Index_Storage::SearchIndex(const KnnDistanceBase1 *knn_distance, const auto centroid_dists = MakeUniqueForOverwrite(nprobe); search_top_k_with_dis(nprobe, dimension, 1, query_f32_ptr, centroids_num, centroids_data, nprobe_result.data(), centroid_dists.get(), false); } - for (const auto part_id : nprobe_result) { - ivf_parts_storage_->SearchIndex(part_id, this, knn_distance, query_ptr, query_element_type, satisfy_filter_func, add_result_func); - } + ivf_parts_storage_->SearchIndex(nprobe_result, this, knn_distance, query_ptr, query_element_type, satisfy_filter_func, add_result_func); } } // namespace infinity diff --git a/src/storage/knn_index/knn_ivf/ivf_index_storage.cppm b/src/storage/knn_index/knn_ivf/ivf_index_storage.cppm index 463225390e..0169289b77 100644 --- a/src/storage/knn_index/knn_ivf/ivf_index_storage.cppm +++ b/src/storage/knn_index/knn_ivf/ivf_index_storage.cppm @@ -69,7 +69,7 @@ public: virtual void AppendOneEmbedding(u32 part_id, const void *embedding_ptr, SegmentOffset segment_offset, const IVF_Centroids_Storage *ivf_centroids_storage) = 0; - virtual void SearchIndex(u32 part_id, + virtual void SearchIndex(const Vector &part_ids, const IVF_Index_Storage *ivf_index_storage, const KnnDistanceBase1 *knn_distance, const void *query_ptr, diff --git a/src/storage/knn_index/knn_ivf/ivf_index_storage_parts.cpp b/src/storage/knn_index/knn_ivf/ivf_index_storage_parts.cpp index cf11f184f3..ddaa81a852 100644 --- a/src/storage/knn_index/knn_ivf/ivf_index_storage_parts.cpp +++ b/src/storage/knn_index/knn_ivf/ivf_index_storage_parts.cpp @@ -15,7 +15,9 @@ module; #include +#include #include + module ivf_index_storage; import stl; @@ -42,6 +44,10 @@ import simd_functions; namespace infinity { +struct SearchIndexPartsReuseContext { + UniquePtr pq_query_ip_table_ = {}; +}; + class IVF_Part_Storage { const u32 part_id_ = std::numeric_limits::max(); @@ -81,7 +87,8 @@ class IVF_Part_Storage { const void *query_ptr, EmbeddingDataType query_element_type, const std::function &satisfy_filter_func, - const std::function &add_result_func) const = 0; + const std::function &add_result_func, + SearchIndexPartsReuseContext &context) const = 0; }; template @@ -101,6 +108,125 @@ class IVF_Parts_Storage_Info : public IVF_P void Train(const u32 training_embedding_num, const f32 *training_data, const IVF_Centroids_Storage *ivf_centroids_storage) final {} }; +UniquePtr GetTrainingResidual(const u32 training_embedding_num, + const f32 *training_data, + const IVF_Centroids_Storage *ivf_centroids_storage, + const u32 embedding_dimension, + const u32 centroids_num) { + auto residuals = MakeUniqueForOverwrite(training_embedding_num * embedding_dimension); + const f32 *centroids_data = ivf_centroids_storage->data(); + // (-0.5 * norm) for each centroid + UniquePtr centroid_norms_neg_half = MakeUniqueForOverwrite(centroids_num); + { + // prepare centroid_norms_neg_half_ + const f32 *centroids_data_ptr = centroids_data; + for (u32 i = 0; i < centroids_num; ++i) { + centroid_norms_neg_half[i] = -0.5f * L2NormSquare(centroids_data_ptr, embedding_dimension); + centroids_data_ptr += embedding_dimension; + } + } + // distance: for every embedding, e * c - 0.5 * c^2, find max + const auto dist_table = MakeUniqueForOverwrite(training_embedding_num * centroids_num); + matrixA_multiply_transpose_matrixB_output_to_C(training_data, + centroids_data, + training_embedding_num, + centroids_num, + embedding_dimension, + dist_table.get()); + for (u32 i = 0; i < training_embedding_num; ++i) { + const f32 *embedding_data_ptr = training_data + i * embedding_dimension; + f32 *output_ptr = residuals.get() + i * embedding_dimension; + f32 max_neg_distance = std::numeric_limits::lowest(); + u64 max_id = 0; + const f32 *dist_ptr = dist_table.get() + i * centroids_num; + for (u32 k = 0; k < centroids_num; ++k) { + if (const f32 neg_distance = dist_ptr[k] + centroid_norms_neg_half[k]; neg_distance > max_neg_distance) { + max_neg_distance = neg_distance; + max_id = k; + } + } + const f32 *centroids_data_ptr = centroids_data + max_id * embedding_dimension; + for (u32 j = 0; j < embedding_dimension; ++j) { + output_ptr[j] = embedding_data_ptr[j] - centroids_data_ptr[j]; + } + } + return residuals; +} + +// quantize with middle 99% range +// r = a * n + b (*: Hadamard product), v = c + r + e +// common a, b for all residuals +// n in [0, 15) or [0, 255) +// query x, centroid c +// dot(x, v) = dot(x, (c + r + e)) = dot(x, c + b) + dot((x * a), n) + dot(v, e) +// dot(v, v) = dot(c + b + a * n + e, v) = dot(c + b, c + b) + dot(a^2, n^2) + 2 * dot ((c + b) * a, n) + 2 * dot(v, e) +// l2(a * n + b + c - x) = dot(a^2, n^2) + l2(b + c - x) + 2 * dot(a * (b + c - x), n) + +template <> +class IVF_Parts_Storage_Info : public IVF_Parts_Storage { + const u32 sq_bits_ = 0; // 4 or 8 + Vector common_vec_a_; + Vector common_vec_b_; + +public: + IVF_Parts_Storage_Info(const u32 embedding_dimension, + const u32 centroids_num, + const EmbeddingDataType embedding_data_type, + const IndexIVFStorageOption &ivf_storage_option) + : IVF_Parts_Storage(embedding_dimension, centroids_num), sq_bits_(ivf_storage_option.scalar_quantization_bits_), + common_vec_a_(embedding_dimension), common_vec_b_(embedding_dimension) { + assert(sq_bits_ == 4 || sq_bits_ == 8); + } + ~IVF_Parts_Storage_Info() override = default; + + const f32 *data_a() const { return common_vec_a_.data(); } + const f32 *data_b() const { return common_vec_b_.data(); } + + void Save(LocalFileHandle &file_handle) const override { + file_handle.Append(common_vec_a_.data(), embedding_dimension() * sizeof(f32)); + file_handle.Append(common_vec_b_.data(), embedding_dimension() * sizeof(f32)); + } + void Load(LocalFileHandle &file_handle) override { + file_handle.Read(common_vec_a_.data(), embedding_dimension() * sizeof(f32)); + file_handle.Read(common_vec_b_.data(), embedding_dimension() * sizeof(f32)); + } + void Train(const u32 training_embedding_num, const f32 *training_data, const IVF_Centroids_Storage *ivf_centroids_storage) final { + const auto residuals = + GetTrainingResidual(training_embedding_num, training_data, ivf_centroids_storage, embedding_dimension(), centroids_num()); + const auto exclude_num_each_end = std::min((training_embedding_num / 200u) + 1u, training_embedding_num); + for (u32 i = 0; i < embedding_dimension(); ++i) { + std::priority_queue min_heap; // biggest on top + std::priority_queue, std::greater<>> max_heap; // smallest on top + const f32 *residual_data = residuals.get() + i; + for (u32 j = 0; j < exclude_num_each_end; ++j) { + const f32 x = *residual_data; + min_heap.push(x); + max_heap.push(x); + residual_data += embedding_dimension(); + } + for (u32 j = exclude_num_each_end; j < training_embedding_num; ++j) { + const f32 x = *residual_data; + if (x < min_heap.top()) { + min_heap.pop(); + min_heap.push(x); + } + if (x > max_heap.top()) { + max_heap.pop(); + max_heap.push(x); + } + residual_data += embedding_dimension(); + } + const f32 range_start = min_heap.top(); + const f32 range_end = max_heap.top(); + assert(range_start <= range_end); + const u32 range_parts = 1u << sq_bits_; + const f32 range_stride = (range_end - range_start) / range_parts; + common_vec_a_[i] = range_stride; + common_vec_b_[i] = range_start + 0.5f * range_stride; + } + } +}; + template <> class IVF_Parts_Storage_Info : public IVF_Parts_Storage { const u32 subspace_num_ = 0; @@ -109,8 +235,6 @@ class IVF_Parts_Storage_Info const u32 expect_subspace_centroid_num_ = 1u << subspace_centroid_bits_; u32 real_subspace_centroid_num_ = 0; - // (-0.5 * norm) for each centroid - UniquePtr centroid_norms_neg_half_ = MakeUniqueForOverwrite(centroids_num()); // centroids for each subspace, size: subspace_dimension_ * real_subspace_centroid_num_ * subspace_num_ UniquePtr subspace_centroids_data_ = {}; // size: real_subspace_centroid_num_ * subspace_num_ @@ -147,14 +271,12 @@ class IVF_Parts_Storage_Info void Save(LocalFileHandle &file_handle) const override { file_handle.Append(&real_subspace_centroid_num_, sizeof(real_subspace_centroid_num_)); - file_handle.Append(centroid_norms_neg_half_.get(), centroids_num() * sizeof(f32)); file_handle.Append(subspace_centroids_data_.get(), embedding_dimension() * real_subspace_centroid_num_ * sizeof(f32)); file_handle.Append(subspace_centroid_norms_neg_half_.get(), real_subspace_centroid_num_ * subspace_num_ * sizeof(f32)); } void Load(LocalFileHandle &file_handle) override { file_handle.Read(&real_subspace_centroid_num_, sizeof(real_subspace_centroid_num_)); - file_handle.Read(centroid_norms_neg_half_.get(), centroids_num() * sizeof(f32)); subspace_centroids_data_ = MakeUniqueForOverwrite(embedding_dimension() * real_subspace_centroid_num_); file_handle.Read(subspace_centroids_data_.get(), embedding_dimension() * real_subspace_centroid_num_ * sizeof(f32)); subspace_centroid_norms_neg_half_ = MakeUniqueForOverwrite(real_subspace_centroid_num_ * subspace_num_); @@ -162,43 +284,8 @@ class IVF_Parts_Storage_Info } void Train(const u32 training_embedding_num, const f32 *training_data, const IVF_Centroids_Storage *ivf_centroids_storage) final { - const f32 *centroids_data = ivf_centroids_storage->data(); - { - // prepare centroid_norms_neg_half_ - const f32 *centroids_data_ptr = centroids_data; - for (u32 i = 0; i < centroids_num(); ++i) { - centroid_norms_neg_half_[i] = -0.5f * L2NormSquare(centroids_data_ptr, embedding_dimension()); - centroids_data_ptr += embedding_dimension(); - } - } - const auto residuals = MakeUniqueForOverwrite(training_embedding_num * embedding_dimension()); - { - // distance: for every embedding, e * c - 0.5 * c^2, find max - const auto dist_table = MakeUniqueForOverwrite(training_embedding_num * centroids_num()); - matrixA_multiply_transpose_matrixB_output_to_C(training_data, - centroids_data, - training_embedding_num, - centroids_num(), - embedding_dimension(), - dist_table.get()); - for (u32 i = 0; i < training_embedding_num; ++i) { - const f32 *embedding_data_ptr = training_data + i * embedding_dimension(); - f32 *output_ptr = residuals.get() + i * embedding_dimension(); - f32 max_neg_distance = std::numeric_limits::lowest(); - u64 max_id = 0; - const f32 *dist_ptr = dist_table.get() + i * centroids_num(); - for (u32 k = 0; k < centroids_num(); ++k) { - if (const f32 neg_distance = dist_ptr[k] + centroid_norms_neg_half_[k]; neg_distance > max_neg_distance) { - max_neg_distance = neg_distance; - max_id = k; - } - } - const f32 *centroids_data_ptr = centroids_data + max_id * embedding_dimension(); - for (u32 j = 0; j < embedding_dimension(); ++j) { - output_ptr[j] = embedding_data_ptr[j] - centroids_data_ptr[j]; - } - } - } + const auto residuals = + GetTrainingResidual(training_embedding_num, training_data, ivf_centroids_storage, embedding_dimension(), centroids_num()); // train residuals TrainResidual(training_embedding_num, residuals.get()); } @@ -333,15 +420,18 @@ class IVF_Parts_Storage_T final : public IVF_Parts_Storage_Info { return ivf_part_storages_[part_id]->AppendOneEmbedding(embedding_ptr, segment_offset, ivf_centroids_storage, this); } - void SearchIndex(const u32 part_id, + void SearchIndex(const Vector &part_ids, const IVF_Index_Storage *ivf_index_storage, const KnnDistanceBase1 *knn_distance, const void *query_ptr, const EmbeddingDataType query_element_type, const std::function &satisfy_filter_func, const std::function &add_result_func) const override { - return ivf_part_storages_[part_id] - ->SearchIndex(ivf_index_storage, knn_distance, query_ptr, query_element_type, satisfy_filter_func, add_result_func); + SearchIndexPartsReuseContext context; + for (const auto part_id : part_ids) { + ivf_part_storages_[part_id] + ->SearchIndex(ivf_index_storage, knn_distance, query_ptr, query_element_type, satisfy_filter_func, add_result_func, context); + } } }; @@ -357,8 +447,7 @@ UniquePtr IVF_Parts_Storage::Make(const u32 embedding_dimensi return GetPartsStorageT.template operator()(); } case IndexIVFStorageOption::Type::kScalarQuantization: { - UnrecoverableError("Not implemented now."); - // return GetPartsStorageT.template operator()(); + return GetPartsStorageT.template operator()(); } case IndexIVFStorageOption::Type::kProductQuantization: { return GetPartsStorageT.template operator()(); @@ -425,7 +514,8 @@ class IVF_Part_Storage_Plain final : public IVF_Part_Storage { const void *query_ptr, const EmbeddingDataType query_element_type, const std::function &satisfy_filter_func, - const std::function &add_result_func) const override { + const std::function &add_result_func, + SearchIndexPartsReuseContext &) const override { auto ReturnT = [&] { if constexpr ((query_element_type == EmbeddingDataType::kElemFloat && IsAnyOf) || (query_element_type == src_embedding_data_type && @@ -479,6 +569,242 @@ class IVF_Part_Storage_Plain final : public IVF_Part_Storage { } }; +// SQ storage +template +class IVF_Part_Storage_SQ final : public IVF_Part_Storage { + static_assert(sq_bits == 4 || sq_bits == 8); + static constexpr u32 sq_encode_start = 0; + static constexpr u32 sq_encode_end = (1u << sq_bits) - 1u; + using ColumnEmbeddingElementT = EmbeddingDataTypeToCppTypeT; + static_assert(IsAnyOf); + + const u32 embedding_dimension_ = 0; + const u32 embedding_sq_bytes_ = (sq_bits == 8) ? embedding_dimension_ : ((embedding_dimension_ + 1) / 2); + Vector sq_data_{}; + Vector dot_v_e_{}; + +public: + IVF_Part_Storage_SQ(const u32 part_id, const u32 embedding_dimension) : IVF_Part_Storage(part_id), embedding_dimension_(embedding_dimension) {} + + void Save(LocalFileHandle &file_handle) const override { + IVF_Part_Storage::Save(file_handle); + const u32 element_cnt = embedding_num() * embedding_sq_bytes_; + assert(element_cnt == sq_data_.size()); + file_handle.Append(sq_data_.data(), element_cnt * sizeof(u8)); + assert(embedding_num() == dot_v_e_.size()); + file_handle.Append(dot_v_e_.data(), embedding_num() * sizeof(f32)); + } + + void Load(LocalFileHandle &file_handle) override { + IVF_Part_Storage::Load(file_handle); + const u32 element_cnt = embedding_num() * embedding_sq_bytes_; + sq_data_.resize(element_cnt); + file_handle.Read(sq_data_.data(), element_cnt * sizeof(u8)); + dot_v_e_.resize(embedding_num()); + file_handle.Read(dot_v_e_.data(), embedding_num() * sizeof(f32)); + } + + void AppendOneEmbedding(const void *embedding_ptr, + const SegmentOffset segment_offset, + const IVF_Centroids_Storage *ivf_centroids_storage, + const IVF_Parts_Storage *ivf_parts_storage) override { + const auto *src_embedding_data = static_cast(embedding_ptr); + const auto [src_embedding_f32, _] = GetF32Ptr(src_embedding_data, embedding_dimension_); + Vector encode_output(embedding_sq_bytes_); + Vector error_v(embedding_dimension_); + const auto *ivf_parts_storage_info = + dynamic_cast *>(ivf_parts_storage); + assert(ivf_parts_storage_info); + const auto centroid_data = ivf_centroids_storage->data() + part_id() * embedding_dimension_; + for (u32 i = 0; i < embedding_dimension_; ++i) { + const f32 residual_i = src_embedding_f32[i] - centroid_data[i]; + const f32 a = ivf_parts_storage_info->data_a()[i]; + const f32 b = ivf_parts_storage_info->data_b()[i]; + u32 n = {}; + if (const f32 n_f = std::round((residual_i - b) / (a > 0.0f ? a : 1.0f)); n_f < sq_encode_start) { + n = sq_encode_start; + } else if (n_f >= sq_encode_end) { + n = sq_encode_end; + } else { + n = static_cast(n_f); + } + if constexpr (sq_bits == 8) { + encode_output[i] = static_cast(n); + } else { + static_assert(sq_bits == 4); + encode_output[i / 2] |= (i & 1) ? (n << 4) : n; + } + error_v[i] = residual_i - (a * n + b); + } + const auto v_e_ip = IPDistance(src_embedding_f32, error_v.data(), embedding_dimension_); + dot_v_e_.push_back(v_e_ip); + sq_data_.insert(sq_data_.end(), encode_output.begin(), encode_output.end()); + embedding_segment_offsets_.push_back(segment_offset); + ++embedding_num_; + } + + void SearchIndex(const IVF_Index_Storage *ivf_index_storage, + const KnnDistanceBase1 *knn_distance, + const void *query_ptr, + const EmbeddingDataType query_element_type, + const std::function &satisfy_filter_func, + const std::function &add_result_func, + SearchIndexPartsReuseContext &context) const override { + auto ReturnT = [&] { + if constexpr ((query_element_type == EmbeddingDataType::kElemFloat && IsAnyOf) || + (query_element_type == src_embedding_data_type && + (query_element_type == EmbeddingDataType::kElemInt8 || query_element_type == EmbeddingDataType::kElemUInt8))) { + return SearchIndexT(ivf_index_storage, + knn_distance, + static_cast *>(query_ptr), + satisfy_filter_func, + add_result_func, + context); + } else { + UnrecoverableError("Invalid Query EmbeddingDataType"); + } + }; + switch (query_element_type) { + case EmbeddingDataType::kElemFloat: { + return ReturnT.template operator()(); + } + case EmbeddingDataType::kElemUInt8: { + return ReturnT.template operator()(); + } + case EmbeddingDataType::kElemInt8: { + return ReturnT.template operator()(); + } + default: { + UnrecoverableError("Invalid EmbeddingDataType"); + } + } + } + + static inline u32 sq_decode(const u8 *sq_data, const u32 pos) + requires(sq_bits == 8) + { + return sq_data[pos]; + } + + static inline u32 sq_decode(const u8 *sq_data, const u32 pos) + requires(sq_bits == 4) + { + return (pos & 1) ? (sq_data[pos / 2] >> 4) : (sq_data[pos / 2] & 0xf); + } + + template + void SearchIndexT(const IVF_Index_Storage *ivf_index_storage, + const KnnDistanceBase1 *knn_distance, + const EmbeddingDataTypeToCppTypeT *query_ptr, + const std::function &satisfy_filter_func, + const std::function &add_result_func, + SearchIndexPartsReuseContext &context) const { + using QueryDataType = EmbeddingDataTypeToCppTypeT; + static_assert(std::is_same_v); + auto knn_distance_1 = dynamic_cast *>(knn_distance); + if (!knn_distance_1) [[unlikely]] { + UnrecoverableError("Invalid KnnDistance1"); + } + const auto &ivf_parts_storage = + static_cast &>(ivf_index_storage->ivf_parts_storage()); + const auto dimension = embedding_dimension_; + const auto a_ptr = ivf_parts_storage.data_a(); + const auto b_ptr = ivf_parts_storage.data_b(); + const auto c_ptr = ivf_index_storage->ivf_centroids_storage().data() + part_id() * dimension; + const auto [x_ptr, _] = GetF32Ptr(query_ptr, dimension); + const auto total_embedding_num = embedding_num(); + switch (const KnnDistanceType dist_type = knn_distance_1->dist_type_; dist_type) { + case KnnDistanceType::kInnerProduct: { + // dot(x, v) = dot(x, c + b) + dot((x * a), n) + dot(v, e) + const auto c_plus_b = MakeUniqueForOverwrite(dimension); + const auto x_mult_a = MakeUniqueForOverwrite(dimension); + for (u32 i = 0; i < dimension; ++i) { + c_plus_b[i] = c_ptr[i] + b_ptr[i]; + x_mult_a[i] = x_ptr[i] * a_ptr[i]; + } + const auto dot_x_c_b = IPDistance(x_ptr, c_plus_b.get(), dimension); + for (u32 i = 0; i < total_embedding_num; ++i) { + const auto segment_offset = embedding_segment_offset(i); + if (!satisfy_filter_func(segment_offset)) { + continue; + } + f32 d = dot_x_c_b + dot_v_e_[i]; + const u8 *sq_data = sq_data_.data() + i * embedding_sq_bytes_; + for (u32 j = 0; j < dimension; ++j) { + d += x_mult_a[j] * sq_decode(sq_data, j); + } + add_result_func(d, segment_offset); + } + break; + } + case KnnDistanceType::kCosine: { + // dot(x, v) = dot(x, c + b) + dot((x * a), n) + dot(v, e) + // dot(v, v) = dot(c + b, c + b) + dot(a^2, n^2) + 2 * dot ((c + b) * a, n) + 2 * dot(v, e) + const auto x_l2 = L2NormSquare(x_ptr, dimension); + const auto c_plus_b = MakeUniqueForOverwrite(dimension); + const auto x_mult_a = MakeUniqueForOverwrite(dimension); + const auto a_square = MakeUniqueForOverwrite(dimension); + const auto c_plus_b_mult_a_2 = MakeUniqueForOverwrite(dimension); + for (u32 i = 0; i < dimension; ++i) { + c_plus_b[i] = c_ptr[i] + b_ptr[i]; + x_mult_a[i] = x_ptr[i] * a_ptr[i]; + a_square[i] = a_ptr[i] * a_ptr[i]; + c_plus_b_mult_a_2[i] = 2.0f * c_plus_b[i] * a_ptr[i]; + } + const auto dot_x_c_b = IPDistance(x_ptr, c_plus_b.get(), dimension); + const auto c_plus_b_l2 = L2NormSquare(c_plus_b.get(), dimension); + for (u32 i = 0; i < total_embedding_num; ++i) { + const auto segment_offset = embedding_segment_offset(i); + if (!satisfy_filter_func(segment_offset)) { + continue; + } + f32 x_v_ip = dot_x_c_b + dot_v_e_[i]; + f32 v_l2 = c_plus_b_l2 + 2.0f * dot_v_e_[i]; + const u8 *sq_data = sq_data_.data() + i * embedding_sq_bytes_; + for (u32 j = 0; j < dimension; ++j) { + const auto n = sq_decode(sq_data, j); + x_v_ip += x_mult_a[j] * n; + v_l2 += (a_square[j] * n + c_plus_b_mult_a_2[j]) * n; + } + const f32 d = x_v_ip / std::sqrt(x_l2 * v_l2); + add_result_func(d, segment_offset); + } + break; + } + case KnnDistanceType::kL2: { + // l2(a * n + b + c - x) = dot(a^2, n^2) + l2(b + c - x) + 2 * dot(a * (b + c - x), n) + const auto a_square = MakeUniqueForOverwrite(dimension); + const auto b_plus_c_minus_x = MakeUniqueForOverwrite(dimension); + const auto a_mult_b_plus_c_minus_x_2 = MakeUniqueForOverwrite(dimension); + for (u32 i = 0; i < dimension; ++i) { + a_square[i] = a_ptr[i] * a_ptr[i]; + b_plus_c_minus_x[i] = b_ptr[i] + c_ptr[i] - x_ptr[i]; + a_mult_b_plus_c_minus_x_2[i] = 2.0f * a_ptr[i] * b_plus_c_minus_x[i]; + } + const auto b_plus_c_minus_x_l2 = L2NormSquare(b_plus_c_minus_x.get(), dimension); + for (u32 i = 0; i < total_embedding_num; ++i) { + const auto segment_offset = embedding_segment_offset(i); + if (!satisfy_filter_func(segment_offset)) { + continue; + } + f32 d = b_plus_c_minus_x_l2; + const u8 *sq_data = sq_data_.data() + i * embedding_sq_bytes_; + for (u32 j = 0; j < dimension; ++j) { + const auto n = sq_decode(sq_data, j); + d += (a_square[j] * n + a_mult_b_plus_c_minus_x_2[j]) * n; + } + add_result_func(d, segment_offset); + } + break; + } + default: { + RecoverableError(Status::SyntaxError(fmt::format("IVFSQ does not support {} metric now.", KnnExpr::KnnDistanceType2Str(dist_type)))); + break; + } + } + } +}; + // PQ storage struct PQ_Code_Storage { const u64 subspace_num_ = 0; @@ -636,7 +962,7 @@ UniquePtr GetPQCodeStorage(const u32 subspace_num, const u32 su template class IVF_Part_Storage_PQ final : public IVF_Part_Storage { using ColumnEmbeddingElementT = EmbeddingDataTypeToCppTypeT; - static_assert(IsAnyOf); + static_assert(IsAnyOf); const u32 subspace_num_ = 0; const u32 subspace_bits_ = 0; @@ -685,7 +1011,8 @@ class IVF_Part_Storage_PQ final : public IVF_Part_Storage { const void *query_ptr, const EmbeddingDataType query_element_type, const std::function &satisfy_filter_func, - const std::function &add_result_func) const override { + const std::function &add_result_func, + SearchIndexPartsReuseContext &context) const override { auto ReturnT = [&] { if constexpr ((query_element_type == EmbeddingDataType::kElemFloat && IsAnyOf) || (query_element_type == src_embedding_data_type && @@ -694,7 +1021,8 @@ class IVF_Part_Storage_PQ final : public IVF_Part_Storage { knn_distance, static_cast *>(query_ptr), satisfy_filter_func, - add_result_func); + add_result_func, + context); } else { UnrecoverableError("Invalid Query EmbeddingDataType"); } @@ -720,8 +1048,10 @@ class IVF_Part_Storage_PQ final : public IVF_Part_Storage { const KnnDistanceBase1 *knn_distance, const EmbeddingDataTypeToCppTypeT *query_ptr, const std::function &satisfy_filter_func, - const std::function &add_result_func) const { + const std::function &add_result_func, + SearchIndexPartsReuseContext &context) const { using QueryDataType = EmbeddingDataTypeToCppTypeT; + static_assert(std::is_same_v); auto knn_distance_1 = dynamic_cast *>(knn_distance); if (!knn_distance_1) [[unlikely]] { UnrecoverableError("Invalid KnnDistance1"); @@ -733,13 +1063,15 @@ class IVF_Part_Storage_PQ final : public IVF_Part_Storage { const auto dimension = ivf_index_storage->embedding_dimension(); const auto [query_f32, _] = GetF32Ptr(query_ptr, dimension); const auto centroid_data = ivf_index_storage->ivf_centroids_storage().data() + part_id() * dimension; - const auto ip_func = GetSIMD_FUNCTIONS().IPDistance_func_ptr_; + const auto total_embedding_num = embedding_num(); switch (const KnnDistanceType dist_type = knn_distance_1->dist_type_; dist_type) { case KnnDistanceType::kInnerProduct: { - const auto query_centroid_ip = ip_func(query_f32, centroid_data, dimension); - const auto ip_table = ivf_parts_storage.GetIPTable(query_f32); + const auto query_centroid_ip = IPDistance(query_f32, centroid_data, dimension); + auto &ip_table = context.pq_query_ip_table_; + if (!ip_table) { + ip_table = ivf_parts_storage.GetIPTable(query_f32); + } const auto encoded_codes = MakeUniqueForOverwrite(subspace_num_); - const auto total_embedding_num = embedding_num(); for (u32 i = 0; i < total_embedding_num; ++i) { const auto segment_offset = embedding_segment_offset(i); if (!satisfy_filter_func(segment_offset)) { @@ -755,13 +1087,15 @@ class IVF_Part_Storage_PQ final : public IVF_Part_Storage { break; } case KnnDistanceType::kCosine: { - const auto query_l2 = L2NormSquare(query_f32, dimension); - const auto centroid_l2 = L2NormSquare(centroid_data, dimension); - const auto query_centroid_ip = ip_func(query_f32, centroid_data, dimension); - const auto query_ip_table = ivf_parts_storage.GetIPTable(query_f32); + const auto query_l2 = L2NormSquare(query_f32, dimension); + const auto centroid_l2 = L2NormSquare(centroid_data, dimension); + const auto query_centroid_ip = IPDistance(query_f32, centroid_data, dimension); + auto &query_ip_table = context.pq_query_ip_table_; + if (!query_ip_table) { + query_ip_table = ivf_parts_storage.GetIPTable(query_f32); + } const auto centroid_ip_table = ivf_parts_storage.GetIPTable(centroid_data); const auto encoded_codes = MakeUniqueForOverwrite(subspace_num_); - const auto total_embedding_num = embedding_num(); for (u32 i = 0; i < total_embedding_num; ++i) { const auto segment_offset = embedding_segment_offset(i); if (!satisfy_filter_func(segment_offset)) { @@ -771,9 +1105,9 @@ class IVF_Part_Storage_PQ final : public IVF_Part_Storage { f32 ip = query_centroid_ip; f32 target_l2 = centroid_l2 * 0.5f; for (u32 j = 0; j < subspace_num; ++j) { - ip += query_ip_table[j * real_subspace_centroid_num + encoded_codes[j]]; - target_l2 += centroid_ip_table[j * real_subspace_centroid_num + encoded_codes[j]] - - ivf_parts_storage.subspace_centroid_norms_neg_half_at_subspace(j)[encoded_codes[j]]; + const auto idx = j * real_subspace_centroid_num + encoded_codes[j]; + ip += query_ip_table[idx]; + target_l2 += centroid_ip_table[idx] - ivf_parts_storage.subspace_centroid_norms_neg_half_at_subspace(j)[encoded_codes[j]]; } target_l2 *= 2.0f; const auto d = ip / std::sqrt(query_l2 * target_l2); @@ -786,10 +1120,9 @@ class IVF_Part_Storage_PQ final : public IVF_Part_Storage { for (u32 i = 0; i < dimension; ++i) { residual_query[i] = query_f32[i] - centroid_data[i]; } - const auto residual_query_l2 = L2NormSquare(residual_query.get(), dimension); + const auto residual_query_l2 = L2NormSquare(residual_query.get(), dimension); const auto residual_ip_table = ivf_parts_storage.GetIPTable(residual_query.get()); const auto encoded_codes = MakeUniqueForOverwrite(subspace_num_); - const auto total_embedding_num = embedding_num(); for (u32 i = 0; i < total_embedding_num; ++i) { const auto segment_offset = embedding_segment_offset(i); if (!satisfy_filter_func(segment_offset)) { @@ -866,18 +1199,66 @@ UniquePtr IVF_Part_Storage::Make(const u32 part_id, break; } case IndexIVFStorageOption::Type::kScalarQuantization: { - UnrecoverableError("Not implemented"); + const auto sq_bits = ivf_storage_option.scalar_quantization_bits_; + auto GetSQResult = [part_id, embedding_dimension, sq_bits]() -> UniquePtr { + switch (sq_bits) { + case 4: { + return MakeUnique>(part_id, embedding_dimension); + } + case 8: { + return MakeUnique>(part_id, embedding_dimension); + } + default: { + UnrecoverableError(fmt::format("Invalid scalar quantization bits: {}", sq_bits)); + return {}; + } + } + }; + switch (embedding_data_type) { + case EmbeddingDataType::kElemDouble: { + return GetSQResult.template operator()(); + } + case EmbeddingDataType::kElemFloat: { + return GetSQResult.template operator()(); + } + case EmbeddingDataType::kElemFloat16: { + return GetSQResult.template operator()(); + } + case EmbeddingDataType::kElemBFloat16: { + return GetSQResult.template operator()(); + } + default: { + UnrecoverableError("Unsupported embedding data type for IVFSQ."); + return {}; + } + } break; } case IndexIVFStorageOption::Type::kProductQuantization: { const auto subspace_num = ivf_storage_option.product_quantization_subspace_num_; const auto subspace_bits = ivf_storage_option.product_quantization_subspace_bits_; - return ApplyEmbeddingDataTypeToFunc( - embedding_data_type, - [part_id, subspace_num, subspace_bits]() -> UniquePtr { - return MakeUnique>(part_id, subspace_num, subspace_bits); - }, - [] { return UniquePtr(); }); + auto GetPQResult = [part_id, subspace_num, subspace_bits]() { + return MakeUnique>(part_id, subspace_num, subspace_bits); + }; + switch (embedding_data_type) { + case EmbeddingDataType::kElemDouble: { + return GetPQResult.template operator()(); + } + case EmbeddingDataType::kElemFloat: { + return GetPQResult.template operator()(); + } + case EmbeddingDataType::kElemFloat16: { + return GetPQResult.template operator()(); + } + case EmbeddingDataType::kElemBFloat16: { + return GetPQResult.template operator()(); + } + default: { + UnrecoverableError("Unsupported embedding data type for IVFPQ."); + return {}; + } + } + break; } } return {};