From 7229c4c0f4eb586938c21fcb45bca40d179c0766 Mon Sep 17 00:00:00 2001 From: vsian Date: Wed, 16 Oct 2024 19:40:31 +0800 Subject: [PATCH 1/6] add slt tests for embedding(bit, *) --- test/sql/dml/insert/test_insert_embedding.slt | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/test/sql/dml/insert/test_insert_embedding.slt b/test/sql/dml/insert/test_insert_embedding.slt index 886ae64085..6f9a23df2c 100644 --- a/test/sql/dml/insert/test_insert_embedding.slt +++ b/test/sql/dml/insert/test_insert_embedding.slt @@ -45,4 +45,16 @@ SELECT * FROM sqllogic_test_insert_embedding; 0 [0.1,1.1,2.1,3.1,4.1,5.1,6.1,7.1,8.1,9.1,10.1,11.1,12.1,13.1,14.1,15.1] statement ok -DROP TABLE sqllogic_test_insert_embedding; \ No newline at end of file +DROP TABLE sqllogic_test_insert_embedding; + +statement ok +CREATE TABLE sqllogic_test_insert_embedding (col1 INT, col2 EMBEDDING(BIT, 8)); + +query I +INSERT INTO sqllogic_test_insert_embedding VALUES(0, [0.0, 0.1, 0.2, 0.3, 0.3, 0.2, 0.1, 0.0]); +---- + +query II +SELECT * FROM sqllogic_test_insert_embedding; +---- +0 [01111110] From deb0377e974b19d23286f0b93eb86b3a3fa0c1d1 Mon Sep 17 00:00:00 2001 From: vsian Date: Fri, 18 Oct 2024 19:51:55 +0800 Subject: [PATCH 2/6] supporting hamming distance with binary vectors --- src/common/simd/distance_simd_functions.cpp | 15 ++++++++++++++- src/common/simd/distance_simd_functions.cppm | 2 ++ src/common/simd/simd_functions.cppm | 1 + src/common/simd/simd_init.cpp | 4 ++++ src/common/simd/simd_init.cppm | 5 +++++ .../physical_scan/physical_knn_scan.cpp | 17 ++++++++++++++--- src/function/table/knn_scan_data.cpp | 14 ++++++++++++-- src/scheduler/fragment_context.cpp | 6 +----- 8 files changed, 53 insertions(+), 11 deletions(-) diff --git a/src/common/simd/distance_simd_functions.cpp b/src/common/simd/distance_simd_functions.cpp index f7ec0b9480..5545037828 100644 --- a/src/common/simd/distance_simd_functions.cpp +++ b/src/common/simd/distance_simd_functions.cpp @@ -14,8 +14,8 @@ module; -#include #include "simd_common_intrin_include.h" +#include /* #if defined(__x86_64__) && (defined(__clang_major__) && (__clang_major__ > 10)) @@ -76,6 +76,19 @@ f32 CosineDistance_common(const f32 *x, const f32 *y, SizeT d) { return dot ? dot / sqrt(sqr_x * sqr_y) : 0.0f; } +f32 HammingDistance_common(const u8 *x, const u8 *y, SizeT d) { + SizeT real_d = d / 8; + f32 result = 0; + for (SizeT i = 0; i < real_d; ++i) { + u8 xor_result = x[i] ^ y[i]; + while (xor_result) { + result += (xor_result | 1); + xor_result >>= 1; + } + } + return result; +} + #if defined(__AVX2__) inline f32 L2Distance_avx2_128(const f32 *vector1, const f32 *vector2, SizeT) { __m256 diff_1 = _mm256_sub_ps(_mm256_loadu_ps(vector1), _mm256_loadu_ps(vector2)); diff --git a/src/common/simd/distance_simd_functions.cppm b/src/common/simd/distance_simd_functions.cppm index 07b4c7bfdd..be977136d2 100644 --- a/src/common/simd/distance_simd_functions.cppm +++ b/src/common/simd/distance_simd_functions.cppm @@ -26,6 +26,8 @@ export f32 IPDistance_common(const f32 *x, const f32 *y, SizeT d); export f32 CosineDistance_common(const f32 *x, const f32 *y, SizeT d); +export f32 HammingDistance_common(const u8 *x, const u8 *y, SizeT d); + #if defined(__AVX2__) export f32 L2Distance_avx2(const f32 *vector1, const f32 *vector2, SizeT dimension); diff --git a/src/common/simd/simd_functions.cppm b/src/common/simd/simd_functions.cppm index bbf6407b61..9fd9a0de0b 100644 --- a/src/common/simd/simd_functions.cppm +++ b/src/common/simd/simd_functions.cppm @@ -25,6 +25,7 @@ export struct SIMD_FUNCTIONS { F32DistanceFuncType L2Distance_func_ptr_ = GetL2DistanceFuncPtr(); F32DistanceFuncType IPDistance_func_ptr_ = GetIPDistanceFuncPtr(); F32DistanceFuncType CosineDistance_func_ptr_ = GetCosineDistanceFuncPtr(); + U8HammingDistanceFuncType HammingDistance_func_ptr_ = GetHammingDistanceFuncPtr(); // HNSW F32 F32DistanceFuncType HNSW_F32L2_ptr_ = Get_HNSW_F32L2_ptr(); diff --git a/src/common/simd/simd_init.cpp b/src/common/simd/simd_init.cpp index 6babbe274f..6cef202d9e 100644 --- a/src/common/simd/simd_init.cpp +++ b/src/common/simd/simd_init.cpp @@ -66,6 +66,10 @@ F32DistanceFuncType GetCosineDistanceFuncPtr() { return &CosineDistance_common; } + U8HammingDistanceFuncType GetHammingDistanceFuncPtr() { + return &HammingDistance_common; +} + F32DistanceFuncType Get_HNSW_F32L2_16_ptr() { #if defined(__AVX512F__) if (IsAVX512Supported()) { diff --git a/src/common/simd/simd_init.cppm b/src/common/simd/simd_init.cppm index 957df9cd08..01dd63ce89 100644 --- a/src/common/simd/simd_init.cppm +++ b/src/common/simd/simd_init.cppm @@ -31,6 +31,7 @@ export using F32DistanceFuncType = f32(*)(const f32 *, const f32 *, SizeT); export using I8DistanceFuncType = i32(*)(const i8 *, const i8 *, SizeT); export using I8CosDistanceFuncType = f32(*)(const i8 *, const i8 *, SizeT); export using U8DistanceFuncType = i32(*)(const u8 *, const u8 *, SizeT); +export using U8HammingDistanceFuncType = f32(*)(const u8 *, const u8 *, SizeT); export using U8CosDistanceFuncType = f32(*)(const u8 *, const u8 *, SizeT); export using MaxSimF32BitIPFuncType = f32(*)(const f32 *, const u8 *, SizeT); export using MaxSimI32BitIPFuncType = i32(*)(const i32 *, const u8 *, SizeT); @@ -42,6 +43,10 @@ export using SearchTop1WithDisF32U32FuncType = void(*)(u32, u32, const f32 *, u3 export F32DistanceFuncType GetL2DistanceFuncPtr(); export F32DistanceFuncType GetIPDistanceFuncPtr(); export F32DistanceFuncType GetCosineDistanceFuncPtr(); + +// u32 distance functions +export U8HammingDistanceFuncType GetHammingDistanceFuncPtr(); + // HNSW F32 export F32DistanceFuncType Get_HNSW_F32L2_ptr(); export F32DistanceFuncType Get_HNSW_F32L2_16_ptr(); diff --git a/src/executor/operator/physical_scan/physical_knn_scan.cpp b/src/executor/operator/physical_scan/physical_knn_scan.cpp index 4243cb824b..9e964a7ca0 100644 --- a/src/executor/operator/physical_scan/physical_knn_scan.cpp +++ b/src/executor/operator/physical_scan/physical_knn_scan.cpp @@ -73,6 +73,7 @@ auto GetKnnExprForCalculation(const KnnExpression &src_knn_expr, const Embedding const auto src_query_embedding_type = src_knn_expr.embedding_data_type_; EmbeddingDataType new_query_embedding_type = EmbeddingDataType::kElemInvalid; switch (column_embedding_type) { + case EmbeddingDataType::kElemBit: case EmbeddingDataType::kElemUInt8: case EmbeddingDataType::kElemInt8: { // expect query embedding to be the same type @@ -86,7 +87,6 @@ auto GetKnnExprForCalculation(const KnnExpression &src_knn_expr, const Embedding // no need for alignment break; } - case EmbeddingDataType::kElemBit: case EmbeddingDataType::kElemInt16: case EmbeddingDataType::kElemInt32: case EmbeddingDataType::kElemInt64: { @@ -228,7 +228,9 @@ void PhysicalKnnScan::ExecuteInternalByColumnLogicalType(QueryContext *query_con case EmbeddingDataType::kElemBFloat16: { return ExecuteInternalByColumnDataType(query_context, knn_scan_operator_state); } - case EmbeddingDataType::kElemBit: + case EmbeddingDataType::kElemBit: { + return ExecuteInternalByColumnDataType(query_context, knn_scan_operator_state); + } case EmbeddingDataType::kElemInt16: case EmbeddingDataType::kElemInt32: case EmbeddingDataType::kElemInt64: @@ -322,7 +324,16 @@ void PhysicalKnnScan::ExecuteInternalByColumnDataType(QueryContext *query_contex UnrecoverableError(fmt::format("BUG: Query embedding data type: {} should be cast to Float before knn search!", EmbeddingT::EmbeddingDataType2String(query_elem_type))); } - case EmbeddingDataType::kElemBit: + case EmbeddingDataType::kElemBit: { + switch (query_dist_type) { + case KnnDistanceType::kHamming: { + return ExecuteDispatchHelper::Execute(this, query_context, knn_scan_operator_state); + } + default: { + return knn_distance_error(); + } + } + } case EmbeddingDataType::kElemInt16: case EmbeddingDataType::kElemInt32: case EmbeddingDataType::kElemInt64: diff --git a/src/function/table/knn_scan_data.cpp b/src/function/table/knn_scan_data.cpp index 4d04611cf5..8847a06c1e 100644 --- a/src/function/table/knn_scan_data.cpp +++ b/src/function/table/knn_scan_data.cpp @@ -94,6 +94,10 @@ KnnDistance1::KnnDistance1(KnnDistanceType dist_type) { dist_func_ = &hnsw_u8ip_f32_wrapper; break; } + case KnnDistanceType::kHamming: { + dist_func_ = GetSIMD_FUNCTIONS().HammingDistance_func_ptr_; + break; + } default: { Status status = Status::NotSupport(fmt::format("KnnDistanceType: {} is not support.", (i32)dist_type)); RecoverableError(status); @@ -161,6 +165,10 @@ KnnScanFunctionData::KnnScanFunctionData(KnnScanSharedData *shared_data, u32 cur Init(); break; } + case EmbeddingDataType::kElemBit: { + Init(); + break; + } default: { Status status = Status::NotSupport(fmt::format("Query EmbeddingDataType: {} is not support.", EmbeddingType::EmbeddingDataType2String(knn_scan_shared_data_->query_elem_type_))); @@ -178,14 +186,16 @@ void KnnScanFunctionData::Init() { } case KnnDistanceType::kL2: case KnnDistanceType::kHamming: { - auto merge_knn_max = MakeUnique>(knn_scan_shared_data_->query_count_, knn_scan_shared_data_->topk_); + auto merge_knn_max = + MakeUnique>(knn_scan_shared_data_->query_count_, knn_scan_shared_data_->topk_); merge_knn_max->Begin(); merge_knn_base_ = std::move(merge_knn_max); break; } case KnnDistanceType::kCosine: case KnnDistanceType::kInnerProduct: { - auto merge_knn_min = MakeUnique>(knn_scan_shared_data_->query_count_, knn_scan_shared_data_->topk_); + auto merge_knn_min = + MakeUnique>(knn_scan_shared_data_->query_count_, knn_scan_shared_data_->topk_); merge_knn_min->Begin(); merge_knn_base_ = std::move(merge_knn_min); break; diff --git a/src/scheduler/fragment_context.cpp b/src/scheduler/fragment_context.cpp index c25ddcbdc9..36ea3afdf1 100644 --- a/src/scheduler/fragment_context.cpp +++ b/src/scheduler/fragment_context.cpp @@ -480,9 +480,6 @@ MakeTaskState(SizeT operator_id, const Vector &physical_ops, case PhysicalOperatorType::kAlter: { return MakeTaskStateTemplate(physical_ops[operator_id]); } - case PhysicalOperatorType::kReadCache: { - return MakeTaskStateTemplate(physical_ops[operator_id]); - } default: { String error_message = fmt::format("Not support {} now", PhysicalOperatorToString(physical_ops[operator_id]->operator_type())); UnrecoverableError(error_message); @@ -1008,8 +1005,7 @@ void FragmentContext::MakeSourceState(i64 parallel_count) { case PhysicalOperatorType::kOptimize: case PhysicalOperatorType::kFlush: case PhysicalOperatorType::kCompactFinish: - case PhysicalOperatorType::kCompactIndexPrepare: - case PhysicalOperatorType::kReadCache: { + case PhysicalOperatorType::kCompactIndexPrepare: { if (fragment_type_ != FragmentType::kSerialMaterialize) { UnrecoverableError( fmt::format("{} should in serial materialized fragment", PhysicalOperatorToString(first_operator->operator_type()))); From 146a709ab0c39f91591e8b992083c6855c3b1b15 Mon Sep 17 00:00:00 2001 From: vsian Date: Fri, 18 Oct 2024 20:00:38 +0800 Subject: [PATCH 3/6] revert some change --- src/scheduler/fragment_context.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/scheduler/fragment_context.cpp b/src/scheduler/fragment_context.cpp index 36ea3afdf1..c25ddcbdc9 100644 --- a/src/scheduler/fragment_context.cpp +++ b/src/scheduler/fragment_context.cpp @@ -480,6 +480,9 @@ MakeTaskState(SizeT operator_id, const Vector &physical_ops, case PhysicalOperatorType::kAlter: { return MakeTaskStateTemplate(physical_ops[operator_id]); } + case PhysicalOperatorType::kReadCache: { + return MakeTaskStateTemplate(physical_ops[operator_id]); + } default: { String error_message = fmt::format("Not support {} now", PhysicalOperatorToString(physical_ops[operator_id]->operator_type())); UnrecoverableError(error_message); @@ -1005,7 +1008,8 @@ void FragmentContext::MakeSourceState(i64 parallel_count) { case PhysicalOperatorType::kOptimize: case PhysicalOperatorType::kFlush: case PhysicalOperatorType::kCompactFinish: - case PhysicalOperatorType::kCompactIndexPrepare: { + case PhysicalOperatorType::kCompactIndexPrepare: + case PhysicalOperatorType::kReadCache: { if (fragment_type_ != FragmentType::kSerialMaterialize) { UnrecoverableError( fmt::format("{} should in serial materialized fragment", PhysicalOperatorToString(first_operator->operator_type()))); From 0d7cc5d4ac33d726ebd33ce50ae12c70e3b729bd Mon Sep 17 00:00:00 2001 From: vsian Date: Mon, 21 Oct 2024 15:38:16 +0800 Subject: [PATCH 4/6] supports binary vector(embedding(bit...))) search with hamming distance --- src/common/simd/distance_simd_functions.cpp | 5 ++- .../physical_scan/physical_knn_scan.cpp | 11 ++++--- src/parser/expr/knn_expr.cpp | 4 +-- .../knn/embedding/test_knn_binary_hamming.slt | 31 +++++++++++++++++++ 4 files changed, 42 insertions(+), 9 deletions(-) create mode 100644 test/sql/dql/knn/embedding/test_knn_binary_hamming.slt diff --git a/src/common/simd/distance_simd_functions.cpp b/src/common/simd/distance_simd_functions.cpp index 5545037828..f97d0b7560 100644 --- a/src/common/simd/distance_simd_functions.cpp +++ b/src/common/simd/distance_simd_functions.cpp @@ -77,12 +77,11 @@ f32 CosineDistance_common(const f32 *x, const f32 *y, SizeT d) { } f32 HammingDistance_common(const u8 *x, const u8 *y, SizeT d) { - SizeT real_d = d / 8; f32 result = 0; - for (SizeT i = 0; i < real_d; ++i) { + for (SizeT i = 0; i < d; ++i) { u8 xor_result = x[i] ^ y[i]; while (xor_result) { - result += (xor_result | 1); + result += (xor_result & 1); xor_result >>= 1; } } diff --git a/src/executor/operator/physical_scan/physical_knn_scan.cpp b/src/executor/operator/physical_scan/physical_knn_scan.cpp index 9e964a7ca0..367817d229 100644 --- a/src/executor/operator/physical_scan/physical_knn_scan.cpp +++ b/src/executor/operator/physical_scan/physical_knn_scan.cpp @@ -425,8 +425,7 @@ void PhysicalKnnScan::PlanWithIndex(QueryContext *query_context) { // TODO: retu RecoverableError(std::move(error_status)); } // check index type - if (auto index_type = table_index_entry->index_base()->index_type_; - index_type != IndexType::kIVF and index_type != IndexType::kHnsw) { + if (auto index_type = table_index_entry->index_base()->index_type_; index_type != IndexType::kIVF and index_type != IndexType::kHnsw) { LOG_ERROR("Invalid index type"); Status error_status = Status::InvalidIndexType("invalid index"); RecoverableError(std::move(error_status)); @@ -820,7 +819,12 @@ struct BruteForceBlockScanSearch(knn_query_ptr, target_ptr, embedding_dim, dist_func->dist_func_, row_count, segment_id, block_id, bitmask); + auto embedding_info = static_cast(column_vector.data_type()->type_info().get()); + if (embedding_info->Type() == EmbeddingDataType::kElemBit) { + merge_heap->Search(knn_query_ptr, target_ptr, embedding_dim / 8, dist_func->dist_func_, row_count, segment_id, block_id, bitmask); + } else { + merge_heap->Search(knn_query_ptr, target_ptr, embedding_dim, dist_func->dist_func_, row_count, segment_id, block_id, bitmask); + } } }; @@ -891,5 +895,4 @@ void MultiVectorSearchOneLine(MergeKnn *merg merge_heap->Search(0, &result_dist, &db_row_id, 1); } - } // namespace infinity diff --git a/src/parser/expr/knn_expr.cpp b/src/parser/expr/knn_expr.cpp index 0c78a6cb4d..a91c6926ce 100644 --- a/src/parser/expr/knn_expr.cpp +++ b/src/parser/expr/knn_expr.cpp @@ -23,8 +23,8 @@ KnnExpr::~KnnExpr() { column_expr_ = nullptr; } - if(opt_params_ != nullptr) { - for(auto* param_ptr: *opt_params_) { + if (opt_params_ != nullptr) { + for (auto *param_ptr : *opt_params_) { delete param_ptr; param_ptr = nullptr; } diff --git a/test/sql/dql/knn/embedding/test_knn_binary_hamming.slt b/test/sql/dql/knn/embedding/test_knn_binary_hamming.slt new file mode 100644 index 0000000000..20c4b48b71 --- /dev/null +++ b/test/sql/dql/knn/embedding/test_knn_binary_hamming.slt @@ -0,0 +1,31 @@ +statement ok +DROP TABLE IF EXISTS test_binary_hamming; + +statement ok +CREATE TABLE test_binary_hamming(c1 INT, c2 EMBEDDING(BIT, 16)); + +query I +INSERT INTO test_binary_hamming VALUES +(0, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), +(1, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]), +(2, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2]), +(3, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3]), +(4, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4]), +(5, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5]); + +query II +SELECT * FROM test_binary_hamming; +---- +0 [0000000000000000] +1 [0000000000000001] +2 [0000000000000011] +3 [0000000000000111] +4 [0000000000001111] +5 [0000000000011111] + +query IF +SELECT c1, DISTANCE() FROM test_binary_hamming SEARCH MATCH VECTOR(c2, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'bit', 'hamming', 3); +---- +0 0.000000 +1 1.000000 +2 2.000000 From cc4fd95ee0504b5b3d2a25d2c5179690d336f0c5 Mon Sep 17 00:00:00 2001 From: vsian Date: Tue, 22 Oct 2024 11:15:10 +0800 Subject: [PATCH 5/6] add simd support --- src/common/simd/distance_simd_functions.cpp | 47 +++++++++- src/common/simd/distance_simd_functions.cppm | 6 ++ src/common/simd/simd_common_tools.cppm | 77 +++++++++++----- src/common/simd/simd_init.cpp | 8 +- src/common/simd/simd_init.cppm | 1 + .../knn/embedding/test_knn_binary_hamming.slt | 88 +++++++++++++++++++ 6 files changed, 201 insertions(+), 26 deletions(-) diff --git a/src/common/simd/distance_simd_functions.cpp b/src/common/simd/distance_simd_functions.cpp index f97d0b7560..52f5bc4a01 100644 --- a/src/common/simd/distance_simd_functions.cpp +++ b/src/common/simd/distance_simd_functions.cpp @@ -80,14 +80,53 @@ f32 HammingDistance_common(const u8 *x, const u8 *y, SizeT d) { f32 result = 0; for (SizeT i = 0; i < d; ++i) { u8 xor_result = x[i] ^ y[i]; - while (xor_result) { - result += (xor_result & 1); - xor_result >>= 1; - } + result += __builtin_popcount(xor_result); + } + return result; +} + +#if defined(__AVX2__) + +f32 HammingDistance_avx2(const u8 *x, const u8 *y, SizeT d) { + f32 result = 0; + SizeT pos = 0; + // 8 * 32 = 256 + for (; pos + 32 < d; pos += 32) { + __m256i xor_result = + _mm256_xor_si256(_mm256_loadu_si256(reinterpret_cast(x)), _mm256_loadu_si256(reinterpret_cast(y))); + result += popcount_avx2(xor_result); + x += 32; + y += 32; + } + if (pos < d) { + result += HammingDistance_common(x, y, d - pos); } return result; } +#endif // defined (__AVX2__) + +#if defined(__SSE2__) + +f32 HammingDistance_sse2(const u8 *x, const u8 *y, SizeT d) { + f32 result = 0; + SizeT pos = 0; + // 8 * 16 = 128 + for (; pos + 16 < d; pos += 16) { + __m128i xor_result = + _mm_xor_si128(_mm_loadu_si128(reinterpret_cast(x)), _mm_loadu_si128(reinterpret_cast(y))); + result += popcount_sse2(xor_result); + x += 16; + y += 16; + } + if (pos < d) { + result += HammingDistance_common(x, y, d - pos); + } + return result; +} + +#endif // defined (__SSE2__) + #if defined(__AVX2__) inline f32 L2Distance_avx2_128(const f32 *vector1, const f32 *vector2, SizeT) { __m256 diff_1 = _mm256_sub_ps(_mm256_loadu_ps(vector1), _mm256_loadu_ps(vector2)); diff --git a/src/common/simd/distance_simd_functions.cppm b/src/common/simd/distance_simd_functions.cppm index be977136d2..112fff9a6b 100644 --- a/src/common/simd/distance_simd_functions.cppm +++ b/src/common/simd/distance_simd_functions.cppm @@ -34,6 +34,12 @@ export f32 L2Distance_avx2(const f32 *vector1, const f32 *vector2, SizeT dimensi export f32 IPDistance_avx2(const f32 *vector1, const f32 *vector2, SizeT dimension); export f32 CosineDistance_avx2(const f32 *vector1, const f32 *vector2, SizeT dimension); + +export f32 HammingDistance_avx2(const u8 *vector1, const u8 *vector2, SizeT dimension); +#endif + +#if defined(__SSE2__) +export f32 HammingDistance_sse2(const u8 *vector1, const u8 *vector2, SizeT dimesion); #endif } // namespace infinity diff --git a/src/common/simd/simd_common_tools.cppm b/src/common/simd/simd_common_tools.cppm index bd5f9a31e6..2923876616 100644 --- a/src/common/simd/simd_common_tools.cppm +++ b/src/common/simd/simd_common_tools.cppm @@ -26,12 +26,12 @@ export U8MaskPtr GetU8MasksForAVX2(); #ifdef __SSE__ // https://stackoverflow.com/questions/6996764/fastest-way-to-do-horizontal-sse-vector-sum-or-other-reduction/35270026#35270026 -export inline float hsum_ps_sse1(__m128 v) { // v = [ D C | B A ] - __m128 shuf = _mm_shuffle_ps(v, v, _MM_SHUFFLE(2, 3, 0, 1)); // [ C D | A B ] - __m128 sums = _mm_add_ps(v, shuf); // sums = [ D+C C+D | B+A A+B ] - shuf = _mm_movehl_ps(shuf, sums); // [ C D | D+C C+D ] // let the compiler avoid a mov by reusing shuf - sums = _mm_add_ss(sums, shuf); - return _mm_cvtss_f32(sums); +export inline float hsum_ps_sse1(__m128 v) { // v = [ D C | B A ] + __m128 shuf = _mm_shuffle_ps(v, v, _MM_SHUFFLE(2, 3, 0, 1)); // [ C D | A B ] + __m128 sums = _mm_add_ps(v, shuf); // sums = [ D+C C+D | B+A A+B ] + shuf = _mm_movehl_ps(shuf, sums); // [ C D | D+C C+D ] // let the compiler avoid a mov by reusing shuf + sums = _mm_add_ss(sums, shuf); + return _mm_cvtss_f32(sums); } #endif @@ -60,41 +60,37 @@ export inline float hsum256_ps_avx(__m256 v) { #ifdef __SSE2__ // https://stackoverflow.com/questions/6996764/fastest-way-to-do-horizontal-sse-vector-sum-or-other-reduction/35270026#35270026 export int hsum_epi32_sse2(__m128i x) { - __m128i hi64 = _mm_shuffle_epi32(x, _MM_SHUFFLE(1, 0, 3, 2)); + __m128i hi64 = _mm_shuffle_epi32(x, _MM_SHUFFLE(1, 0, 3, 2)); __m128i sum64 = _mm_add_epi32(hi64, x); - __m128i hi32 = _mm_shufflelo_epi16(sum64, _MM_SHUFFLE(1, 0, 3, 2)); // Swap the low two elements + __m128i hi32 = _mm_shufflelo_epi16(sum64, _MM_SHUFFLE(1, 0, 3, 2)); // Swap the low two elements __m128i sum32 = _mm_add_epi32(sum64, hi32); - return _mm_cvtsi128_si32(sum32); // SSE2 movd + return _mm_cvtsi128_si32(sum32); // SSE2 movd } #endif #ifdef __AVX__ // https://stackoverflow.com/questions/60108658/fastest-method-to-calculate-sum-of-all-packed-32-bit-integers-using-avx512-or-av/60109639#60109639 -export int hsum_epi32_avx(__m128i x) -{ - __m128i hi64 = _mm_unpackhi_epi64(x, x); // 3-operand non-destructive AVX lets us save a byte without needing a movdqa +export int hsum_epi32_avx(__m128i x) { + __m128i hi64 = _mm_unpackhi_epi64(x, x); // 3-operand non-destructive AVX lets us save a byte without needing a movdqa __m128i sum64 = _mm_add_epi32(hi64, x); - __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); // Swap the low two elements + __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); // Swap the low two elements __m128i sum32 = _mm_add_epi32(sum64, hi32); - return _mm_cvtsi128_si32(sum32); // movd + return _mm_cvtsi128_si32(sum32); // movd } #endif #ifdef __AVX2__ // https://stackoverflow.com/questions/60108658/fastest-method-to-calculate-sum-of-all-packed-32-bit-integers-using-avx512-or-av/60109639#60109639 // only needs AVX2 -export int hsum_8x32_avx2(__m256i v) -{ - __m128i sum128 = _mm_add_epi32( - _mm256_castsi256_si128(v), - _mm256_extracti128_si256(v, 1)); // silly GCC uses a longer AXV512VL instruction if AVX512 is enabled :/ +export int hsum_8x32_avx2(__m256i v) { + __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(v), + _mm256_extracti128_si256(v, 1)); // silly GCC uses a longer AXV512VL instruction if AVX512 is enabled :/ return hsum_epi32_avx(sum128); } #endif #ifdef __AVX512F__ -export int hsum_epi32_avx512(__m512i v) -{ +export int hsum_epi32_avx512(__m512i v) { __m256i lo = _mm512_castsi512_si256(v); __m256i hi = _mm512_extracti64x4_epi64(v, 1); __m256i sum = _mm256_add_epi32(lo, hi); @@ -154,4 +150,43 @@ export inline __m512i abs_sub_epu8_avx512(const __m512i a, const __m512i b) { } #endif +#if defined(__AVX2__) +export inline int popcount_avx2(const __m256i v) { + const __m256i m1 = _mm256_set1_epi8(0x55); + const __m256i m2 = _mm256_set1_epi8(0x33); + const __m256i m4 = _mm256_set1_epi8(0x0F); + + const __m256i t1 = _mm256_sub_epi8(v, (_mm256_srli_epi16(v, 1) & m1)); + const __m256i t2 = _mm256_add_epi8(t1 & m2, (_mm256_srli_epi16(t1, 2) & m2)); + const __m256i t3 = _mm256_add_epi8(t2, _mm256_srli_epi16(t2, 4)) & m4; + __m256i sad = _mm256_sad_epu8(t3, _mm256_setzero_si256()); + + __m128i sad_low = _mm256_extracti128_si256(sad, 0); + __m128i sad_high = _mm256_extracti128_si256(sad, 1); + + sad_low = _mm_add_epi64(sad_low, sad_high); + + int result = _mm_extract_epi64(sad_low, 0) + _mm_extract_epi64(sad_low, 1); + return result; +} +#endif // defined (__AVX2__) + +#if defined(__SSE2__) +export inline int popcount_sse2(const __m128i x) { + const __m128i m1 = _mm_set1_epi8(0x55); + const __m128i m2 = _mm_set1_epi8(0x33); + const __m128i m4 = _mm_set1_epi8(0x0F); + + const __m128i t1 = x; + const __m128i t2 = _mm_sub_epi8(t1, _mm_srli_epi16(t1, 1) & m1); + const __m128i t3 = _mm_add_epi8(t2 & m2, _mm_srli_epi16(t2, 2) & m2); + const __m128i t4 = _mm_add_epi8(t3, _mm_srli_epi16(t3, 4)) & m4; + + __m128i sad = _mm_sad_epu8(t4, _mm_setzero_si128()); + + int result = _mm_extract_epi64(sad, 0) + _mm_extract_epi64(sad, 1); + return result; +} +#endif // defined (__SSE2__) + } // namespace infinity diff --git a/src/common/simd/simd_init.cpp b/src/common/simd/simd_init.cpp index 6cef202d9e..68e6d7fb3d 100644 --- a/src/common/simd/simd_init.cpp +++ b/src/common/simd/simd_init.cpp @@ -66,7 +66,13 @@ F32DistanceFuncType GetCosineDistanceFuncPtr() { return &CosineDistance_common; } - U8HammingDistanceFuncType GetHammingDistanceFuncPtr() { +U8HammingDistanceFuncType GetHammingDistanceFuncPtr() { +#ifdef __AVX2__ + return &HammingDistance_avx2; +#endif +#ifdef __SSE2__ + return &HammingDistance_sse2; +#endif return &HammingDistance_common; } diff --git a/src/common/simd/simd_init.cppm b/src/common/simd/simd_init.cppm index 01dd63ce89..f933775523 100644 --- a/src/common/simd/simd_init.cppm +++ b/src/common/simd/simd_init.cppm @@ -31,6 +31,7 @@ export using F32DistanceFuncType = f32(*)(const f32 *, const f32 *, SizeT); export using I8DistanceFuncType = i32(*)(const i8 *, const i8 *, SizeT); export using I8CosDistanceFuncType = f32(*)(const i8 *, const i8 *, SizeT); export using U8DistanceFuncType = i32(*)(const u8 *, const u8 *, SizeT); +//dimension in hamming distance is in bytes export using U8HammingDistanceFuncType = f32(*)(const u8 *, const u8 *, SizeT); export using U8CosDistanceFuncType = f32(*)(const u8 *, const u8 *, SizeT); export using MaxSimF32BitIPFuncType = f32(*)(const f32 *, const u8 *, SizeT); diff --git a/test/sql/dql/knn/embedding/test_knn_binary_hamming.slt b/test/sql/dql/knn/embedding/test_knn_binary_hamming.slt index 20c4b48b71..be31c42a43 100644 --- a/test/sql/dql/knn/embedding/test_knn_binary_hamming.slt +++ b/test/sql/dql/knn/embedding/test_knn_binary_hamming.slt @@ -29,3 +29,91 @@ SELECT c1, DISTANCE() FROM test_binary_hamming SEARCH MATCH VECTOR(c2, [0, 0, 0, 0 0.000000 1 1.000000 2 2.000000 + +statement ok +DROP TABLE IF EXISTS test_binary_hamming; + +statement ok +CREATE TABLE test_binary_hamming(c1 INT, c2 EMBEDDING(BIT, 264)); + +query I +INSERT INTO test_binary_hamming VALUES +( + 0, + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0 + ] +), +( + 1, + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1 + ] +), +( + 2, + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 2 + ] +), +( + 3, + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, 2, 3 + ] +); + + +query IF +SELECT c1, DISTANCE() FROM test_binary_hamming SEARCH MATCH VECTOR(c2, +[ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0 +], +'bit', 'hamming', 3); +---- +0 0.000000 +1 1.000000 +2 2.000000 + +statement ok +DROP TABLE IF EXISTS test_binary_hamming; From 73da75e0d4a08490ff0b47fcc8b3f5cbfe32ac81 Mon Sep 17 00:00:00 2001 From: vsian Date: Tue, 22 Oct 2024 11:30:39 +0800 Subject: [PATCH 6/6] add github link --- src/common/simd/simd_common_tools.cppm | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/common/simd/simd_common_tools.cppm b/src/common/simd/simd_common_tools.cppm index 2923876616..4c77fa5efc 100644 --- a/src/common/simd/simd_common_tools.cppm +++ b/src/common/simd/simd_common_tools.cppm @@ -150,6 +150,7 @@ export inline __m512i abs_sub_epu8_avx512(const __m512i a, const __m512i b) { } #endif +// https://github.com/WojciechMula/sse-popcount/blob/master/popcnt-avx2-harley-seal.cpp #if defined(__AVX2__) export inline int popcount_avx2(const __m256i v) { const __m256i m1 = _mm256_set1_epi8(0x55); @@ -171,6 +172,7 @@ export inline int popcount_avx2(const __m256i v) { } #endif // defined (__AVX2__) +// https://github.com/WojciechMula/sse-popcount/blob/master/popcnt-sse-harley-seal.cpp #if defined(__SSE2__) export inline int popcount_sse2(const __m128i x) { const __m128i m1 = _mm_set1_epi8(0x55);