Skip to content

Commit

Permalink
Supports Binary vector with Hamming distance. (#2070)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Support Hamming distance as the metric when matching binary vectors.

Issue link: #2069 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Test cases
  • Loading branch information
vsian authored Oct 22, 2024
1 parent 115c9da commit 9a17553
Show file tree
Hide file tree
Showing 11 changed files with 302 additions and 34 deletions.
53 changes: 52 additions & 1 deletion src/common/simd/distance_simd_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

module;

#include <cmath>
#include "simd_common_intrin_include.h"
#include <cmath>

/*
#if defined(__x86_64__) && (defined(__clang_major__) && (__clang_major__ > 10))
Expand Down Expand Up @@ -76,6 +76,57 @@ f32 CosineDistance_common(const f32 *x, const f32 *y, SizeT d) {
return dot ? dot / sqrt(sqr_x * sqr_y) : 0.0f;
}

f32 HammingDistance_common(const u8 *x, const u8 *y, SizeT d) {
    // Scalar fallback: Hamming distance over d *bytes* (bit dimension is d * 8).
    // Accumulate in an integer instead of f32: a float running sum silently
    // loses +1 increments once it exceeds 2^24, which is reachable for very
    // large byte counts. Convert to f32 once at the end.
    SizeT result = 0;
    for (SizeT i = 0; i < d; ++i) {
        const u8 xor_result = x[i] ^ y[i];
        result += __builtin_popcount(xor_result);
    }
    return static_cast<f32>(result);
}

#if defined(__AVX2__)

f32 HammingDistance_avx2(const u8 *x, const u8 *y, SizeT d) {
    // Hamming distance over d bytes, processing 32 bytes (256 bits) per AVX2 step.
    f32 result = 0;
    SizeT pos = 0;
    // Use "<=" so a final chunk of exactly 32 bytes is still handled by SIMD;
    // the original "<" sent it through the scalar tail instead.
    for (; pos + 32 <= d; pos += 32) {
        __m256i xor_result =
            _mm256_xor_si256(_mm256_loadu_si256(reinterpret_cast<const __m256i *>(x)), _mm256_loadu_si256(reinterpret_cast<const __m256i *>(y)));
        result += popcount_avx2(xor_result);
        x += 32;
        y += 32;
    }
    // Scalar tail for the remaining (d % 32) bytes, if any.
    if (pos < d) {
        result += HammingDistance_common(x, y, d - pos);
    }
    return result;
}

#endif // defined (__AVX2__)

#if defined(__SSE2__)

f32 HammingDistance_sse2(const u8 *x, const u8 *y, SizeT d) {
    // Hamming distance over d bytes, processing 16 bytes (128 bits) per SSE2 step.
    f32 result = 0;
    SizeT pos = 0;
    // Use "<=" so a final chunk of exactly 16 bytes is still handled by SIMD;
    // the original "<" sent it through the scalar tail instead.
    for (; pos + 16 <= d; pos += 16) {
        __m128i xor_result =
            _mm_xor_si128(_mm_loadu_si128(reinterpret_cast<const __m128i *>(x)), _mm_loadu_si128(reinterpret_cast<const __m128i *>(y)));
        result += popcount_sse2(xor_result);
        x += 16;
        y += 16;
    }
    // Scalar tail for the remaining (d % 16) bytes, if any.
    if (pos < d) {
        result += HammingDistance_common(x, y, d - pos);
    }
    return result;
}

#endif // defined (__SSE2__)

#if defined(__AVX2__)
inline f32 L2Distance_avx2_128(const f32 *vector1, const f32 *vector2, SizeT) {
__m256 diff_1 = _mm256_sub_ps(_mm256_loadu_ps(vector1), _mm256_loadu_ps(vector2));
Expand Down
8 changes: 8 additions & 0 deletions src/common/simd/distance_simd_functions.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,20 @@ export f32 IPDistance_common(const f32 *x, const f32 *y, SizeT d);

export f32 CosineDistance_common(const f32 *x, const f32 *y, SizeT d);

export f32 HammingDistance_common(const u8 *x, const u8 *y, SizeT d);

#if defined(__AVX2__)
export f32 L2Distance_avx2(const f32 *vector1, const f32 *vector2, SizeT dimension);

export f32 IPDistance_avx2(const f32 *vector1, const f32 *vector2, SizeT dimension);

export f32 CosineDistance_avx2(const f32 *vector1, const f32 *vector2, SizeT dimension);

export f32 HammingDistance_avx2(const u8 *vector1, const u8 *vector2, SizeT dimension);
#endif

#if defined(__SSE2__)
export f32 HammingDistance_sse2(const u8 *vector1, const u8 *vector2, SizeT dimesion);
#endif

} // namespace infinity
79 changes: 58 additions & 21 deletions src/common/simd/simd_common_tools.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ export U8MaskPtr GetU8MasksForAVX2();

#ifdef __SSE__
// https://stackoverflow.com/questions/6996764/fastest-way-to-do-horizontal-sse-vector-sum-or-other-reduction/35270026#35270026
export inline float hsum_ps_sse1(__m128 v) { // v = [ D C | B A ]
__m128 shuf = _mm_shuffle_ps(v, v, _MM_SHUFFLE(2, 3, 0, 1)); // [ C D | A B ]
__m128 sums = _mm_add_ps(v, shuf); // sums = [ D+C C+D | B+A A+B ]
shuf = _mm_movehl_ps(shuf, sums); // [ C D | D+C C+D ] // let the compiler avoid a mov by reusing shuf
sums = _mm_add_ss(sums, shuf);
return _mm_cvtss_f32(sums);
export inline float hsum_ps_sse1(__m128 v) { // v = [ D C | B A ]
__m128 shuf = _mm_shuffle_ps(v, v, _MM_SHUFFLE(2, 3, 0, 1)); // [ C D | A B ]
__m128 sums = _mm_add_ps(v, shuf); // sums = [ D+C C+D | B+A A+B ]
shuf = _mm_movehl_ps(shuf, sums); // [ C D | D+C C+D ] // let the compiler avoid a mov by reusing shuf
sums = _mm_add_ss(sums, shuf);
return _mm_cvtss_f32(sums);
}
#endif

Expand Down Expand Up @@ -60,41 +60,37 @@ export inline float hsum256_ps_avx(__m256 v) {
#ifdef __SSE2__
// https://stackoverflow.com/questions/6996764/fastest-way-to-do-horizontal-sse-vector-sum-or-other-reduction/35270026#35270026
export int hsum_epi32_sse2(__m128i x) {
__m128i hi64 = _mm_shuffle_epi32(x, _MM_SHUFFLE(1, 0, 3, 2));
__m128i hi64 = _mm_shuffle_epi32(x, _MM_SHUFFLE(1, 0, 3, 2));
__m128i sum64 = _mm_add_epi32(hi64, x);
__m128i hi32 = _mm_shufflelo_epi16(sum64, _MM_SHUFFLE(1, 0, 3, 2)); // Swap the low two elements
__m128i hi32 = _mm_shufflelo_epi16(sum64, _MM_SHUFFLE(1, 0, 3, 2)); // Swap the low two elements
__m128i sum32 = _mm_add_epi32(sum64, hi32);
return _mm_cvtsi128_si32(sum32); // SSE2 movd
return _mm_cvtsi128_si32(sum32); // SSE2 movd
}
#endif

#ifdef __AVX__
// https://stackoverflow.com/questions/60108658/fastest-method-to-calculate-sum-of-all-packed-32-bit-integers-using-avx512-or-av/60109639#60109639
export int hsum_epi32_avx(__m128i x)
{
__m128i hi64 = _mm_unpackhi_epi64(x, x); // 3-operand non-destructive AVX lets us save a byte without needing a movdqa
export int hsum_epi32_avx(__m128i x) {
__m128i hi64 = _mm_unpackhi_epi64(x, x); // 3-operand non-destructive AVX lets us save a byte without needing a movdqa
__m128i sum64 = _mm_add_epi32(hi64, x);
__m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); // Swap the low two elements
__m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); // Swap the low two elements
__m128i sum32 = _mm_add_epi32(sum64, hi32);
return _mm_cvtsi128_si32(sum32); // movd
return _mm_cvtsi128_si32(sum32); // movd
}
#endif

#ifdef __AVX2__
// https://stackoverflow.com/questions/60108658/fastest-method-to-calculate-sum-of-all-packed-32-bit-integers-using-avx512-or-av/60109639#60109639
// only needs AVX2
export int hsum_8x32_avx2(__m256i v)
{
__m128i sum128 = _mm_add_epi32(
_mm256_castsi256_si128(v),
_mm256_extracti128_si256(v, 1)); // silly GCC uses a longer AXV512VL instruction if AVX512 is enabled :/
export int hsum_8x32_avx2(__m256i v) {
__m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(v),
_mm256_extracti128_si256(v, 1)); // silly GCC uses a longer AXV512VL instruction if AVX512 is enabled :/
return hsum_epi32_avx(sum128);
}
#endif

#ifdef __AVX512F__
export int hsum_epi32_avx512(__m512i v)
{
export int hsum_epi32_avx512(__m512i v) {
__m256i lo = _mm512_castsi512_si256(v);
__m256i hi = _mm512_extracti64x4_epi64(v, 1);
__m256i sum = _mm256_add_epi32(lo, hi);
Expand Down Expand Up @@ -154,4 +150,45 @@ export inline __m512i abs_sub_epu8_avx512(const __m512i a, const __m512i b) {
}
#endif

// https://github.com/WojciechMula/sse-popcount/blob/master/popcnt-avx2-harley-seal.cpp
#if defined(__AVX2__)
export inline int popcount_avx2(const __m256i v) {
    // Per-byte popcount via the classic SWAR bit-slicing reduction
    // (pairs -> nibbles), then a horizontal sum of the 32 byte counts
    // using SAD against zero (four 64-bit partial sums).
    // Reference: https://github.com/WojciechMula/sse-popcount
    const __m256i m1 = _mm256_set1_epi8(0x55);
    const __m256i m2 = _mm256_set1_epi8(0x33);
    const __m256i m4 = _mm256_set1_epi8(0x0F);

    // Use explicit _mm256_and_si256 instead of the GCC/Clang `&` vector
    // operator extension so this header also compiles under MSVC.
    const __m256i t1 = _mm256_sub_epi8(v, _mm256_and_si256(_mm256_srli_epi16(v, 1), m1));
    const __m256i t2 = _mm256_add_epi8(_mm256_and_si256(t1, m2), _mm256_and_si256(_mm256_srli_epi16(t1, 2), m2));
    const __m256i t3 = _mm256_and_si256(_mm256_add_epi8(t2, _mm256_srli_epi16(t2, 4)), m4);
    __m256i sad = _mm256_sad_epu8(t3, _mm256_setzero_si256());

    // Low lane via castsi256_si128 (a no-op cast, cheaper than extracti128(.., 0)).
    __m128i sad_low = _mm256_castsi256_si128(sad);
    __m128i sad_high = _mm256_extracti128_si256(sad, 1);
    sad_low = _mm_add_epi64(sad_low, sad_high);

    // _mm_extract_epi64 is SSE4.1, which AVX2 implies, so it is safe here.
    int result = _mm_extract_epi64(sad_low, 0) + _mm_extract_epi64(sad_low, 1);
    return result;
}
#endif // defined (__AVX2__)

// https://github.com/WojciechMula/sse-popcount/blob/master/popcnt-sse-harley-seal.cpp
#if defined(__SSE2__)
export inline int popcount_sse2(const __m128i x) {
    // Per-byte popcount via SWAR bit-slicing (pairs -> nibbles), then a
    // horizontal sum of the 16 byte counts using SAD against zero.
    // Reference: https://github.com/WojciechMula/sse-popcount
    const __m128i m1 = _mm_set1_epi8(0x55);
    const __m128i m2 = _mm_set1_epi8(0x33);
    const __m128i m4 = _mm_set1_epi8(0x0F);

    // Explicit _mm_and_si128 instead of the GCC/Clang `&` vector operator
    // extension, for MSVC portability.
    const __m128i t2 = _mm_sub_epi8(x, _mm_and_si128(_mm_srli_epi16(x, 1), m1));
    const __m128i t3 = _mm_add_epi8(_mm_and_si128(t2, m2), _mm_and_si128(_mm_srli_epi16(t2, 2), m2));
    const __m128i t4 = _mm_and_si128(_mm_add_epi8(t3, _mm_srli_epi16(t3, 4)), m4);

    __m128i sad = _mm_sad_epu8(t4, _mm_setzero_si128());

    // BUGFIX: _mm_extract_epi64 is an SSE4.1 intrinsic and must not be used
    // inside this SSE2-only path. Extract the two 64-bit SAD lanes with
    // SSE2-level intrinsics instead (_mm_cvtsi128_si64 requires x86-64,
    // which this SIMD code already targets).
    const long long lo = _mm_cvtsi128_si64(sad);
    const long long hi = _mm_cvtsi128_si64(_mm_unpackhi_epi64(sad, sad));
    return static_cast<int>(lo + hi);
}
#endif // defined (__SSE2__)

} // namespace infinity
1 change: 1 addition & 0 deletions src/common/simd/simd_functions.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ export struct SIMD_FUNCTIONS {
F32DistanceFuncType L2Distance_func_ptr_ = GetL2DistanceFuncPtr();
F32DistanceFuncType IPDistance_func_ptr_ = GetIPDistanceFuncPtr();
F32DistanceFuncType CosineDistance_func_ptr_ = GetCosineDistanceFuncPtr();
U8HammingDistanceFuncType HammingDistance_func_ptr_ = GetHammingDistanceFuncPtr();

// HNSW F32
F32DistanceFuncType HNSW_F32L2_ptr_ = Get_HNSW_F32L2_ptr();
Expand Down
10 changes: 10 additions & 0 deletions src/common/simd/simd_init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,16 @@ F32DistanceFuncType GetCosineDistanceFuncPtr() {
return &CosineDistance_common;
}

// Select the best available Hamming-distance kernel at compile time:
// AVX2 if enabled, otherwise SSE2, otherwise the scalar fallback.
U8HammingDistanceFuncType GetHammingDistanceFuncPtr() {
#if defined(__AVX2__)
    return &HammingDistance_avx2;
#elif defined(__SSE2__)
    return &HammingDistance_sse2;
#else
    return &HammingDistance_common;
#endif
}

F32DistanceFuncType Get_HNSW_F32L2_16_ptr() {
#if defined(__AVX512F__)
if (IsAVX512Supported()) {
Expand Down
6 changes: 6 additions & 0 deletions src/common/simd/simd_init.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ export using F32DistanceFuncType = f32(*)(const f32 *, const f32 *, SizeT);
export using I8DistanceFuncType = i32(*)(const i8 *, const i8 *, SizeT);
export using I8CosDistanceFuncType = f32(*)(const i8 *, const i8 *, SizeT);
export using U8DistanceFuncType = i32(*)(const u8 *, const u8 *, SizeT);
// For Hamming distance the dimension argument is measured in bytes, not in elements.
export using U8HammingDistanceFuncType = f32(*)(const u8 *, const u8 *, SizeT);
export using U8CosDistanceFuncType = f32(*)(const u8 *, const u8 *, SizeT);
export using MaxSimF32BitIPFuncType = f32(*)(const f32 *, const u8 *, SizeT);
export using MaxSimI32BitIPFuncType = i32(*)(const i32 *, const u8 *, SizeT);
Expand All @@ -42,6 +44,10 @@ export using SearchTop1WithDisF32U32FuncType = void(*)(u32, u32, const f32 *, u3
export F32DistanceFuncType GetL2DistanceFuncPtr();
export F32DistanceFuncType GetIPDistanceFuncPtr();
export F32DistanceFuncType GetCosineDistanceFuncPtr();

// u8 Hamming distance function
export U8HammingDistanceFuncType GetHammingDistanceFuncPtr();

// HNSW F32
export F32DistanceFuncType Get_HNSW_F32L2_ptr();
export F32DistanceFuncType Get_HNSW_F32L2_16_ptr();
Expand Down
28 changes: 21 additions & 7 deletions src/executor/operator/physical_scan/physical_knn_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ auto GetKnnExprForCalculation(const KnnExpression &src_knn_expr, const Embedding
const auto src_query_embedding_type = src_knn_expr.embedding_data_type_;
EmbeddingDataType new_query_embedding_type = EmbeddingDataType::kElemInvalid;
switch (column_embedding_type) {
case EmbeddingDataType::kElemBit:
case EmbeddingDataType::kElemUInt8:
case EmbeddingDataType::kElemInt8: {
// expect query embedding to be the same type
Expand All @@ -86,7 +87,6 @@ auto GetKnnExprForCalculation(const KnnExpression &src_knn_expr, const Embedding
// no need for alignment
break;
}
case EmbeddingDataType::kElemBit:
case EmbeddingDataType::kElemInt16:
case EmbeddingDataType::kElemInt32:
case EmbeddingDataType::kElemInt64: {
Expand Down Expand Up @@ -228,7 +228,9 @@ void PhysicalKnnScan::ExecuteInternalByColumnLogicalType(QueryContext *query_con
case EmbeddingDataType::kElemBFloat16: {
return ExecuteInternalByColumnDataType<t, BFloat16T>(query_context, knn_scan_operator_state);
}
case EmbeddingDataType::kElemBit:
case EmbeddingDataType::kElemBit: {
return ExecuteInternalByColumnDataType<t, u8>(query_context, knn_scan_operator_state);
}
case EmbeddingDataType::kElemInt16:
case EmbeddingDataType::kElemInt32:
case EmbeddingDataType::kElemInt64:
Expand Down Expand Up @@ -322,7 +324,16 @@ void PhysicalKnnScan::ExecuteInternalByColumnDataType(QueryContext *query_contex
UnrecoverableError(fmt::format("BUG: Query embedding data type: {} should be cast to Float before knn search!",
EmbeddingT::EmbeddingDataType2String(query_elem_type)));
}
case EmbeddingDataType::kElemBit:
case EmbeddingDataType::kElemBit: {
switch (query_dist_type) {
case KnnDistanceType::kHamming: {
return ExecuteDispatchHelper<t, ColumnDataT, u8, CompareMax, f32>::Execute(this, query_context, knn_scan_operator_state);
}
default: {
return knn_distance_error();
}
}
}
case EmbeddingDataType::kElemInt16:
case EmbeddingDataType::kElemInt32:
case EmbeddingDataType::kElemInt64:
Expand Down Expand Up @@ -414,8 +425,7 @@ void PhysicalKnnScan::PlanWithIndex(QueryContext *query_context) { // TODO: retu
RecoverableError(std::move(error_status));
}
// check index type
if (auto index_type = table_index_entry->index_base()->index_type_;
index_type != IndexType::kIVF and index_type != IndexType::kHnsw) {
if (auto index_type = table_index_entry->index_base()->index_type_; index_type != IndexType::kIVF and index_type != IndexType::kHnsw) {
LOG_ERROR("Invalid index type");
Status error_status = Status::InvalidIndexType("invalid index");
RecoverableError(std::move(error_status));
Expand Down Expand Up @@ -809,7 +819,12 @@ struct BruteForceBlockScan<LogicalType::kEmbedding, ColumnDataType, QueryDataTyp
}
target_ptr = buffer_ptr_for_cast.get();
}
merge_heap->Search(knn_query_ptr, target_ptr, embedding_dim, dist_func->dist_func_, row_count, segment_id, block_id, bitmask);
auto embedding_info = static_cast<EmbeddingInfo *>(column_vector.data_type()->type_info().get());
if (embedding_info->Type() == EmbeddingDataType::kElemBit) {
merge_heap->Search(knn_query_ptr, target_ptr, embedding_dim / 8, dist_func->dist_func_, row_count, segment_id, block_id, bitmask);
} else {
merge_heap->Search(knn_query_ptr, target_ptr, embedding_dim, dist_func->dist_func_, row_count, segment_id, block_id, bitmask);
}
}
};

Expand Down Expand Up @@ -880,5 +895,4 @@ void MultiVectorSearchOneLine(MergeKnn<QueryDataType, C, DistanceDataType> *merg
merge_heap->Search(0, &result_dist, &db_row_id, 1);
}


} // namespace infinity
14 changes: 12 additions & 2 deletions src/function/table/knn_scan_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ KnnDistance1<u8, f32>::KnnDistance1(KnnDistanceType dist_type) {
dist_func_ = &hnsw_u8ip_f32_wrapper;
break;
}
case KnnDistanceType::kHamming: {
dist_func_ = GetSIMD_FUNCTIONS().HammingDistance_func_ptr_;
break;
}
default: {
Status status = Status::NotSupport(fmt::format("KnnDistanceType: {} is not support.", (i32)dist_type));
RecoverableError(status);
Expand Down Expand Up @@ -166,6 +170,10 @@ KnnScanFunctionData::KnnScanFunctionData(KnnScanSharedData *shared_data, u32 cur
Init<i8, f32>();
break;
}
case EmbeddingDataType::kElemBit: {
Init<u8, f32>();
break;
}
default: {
Status status = Status::NotSupport(fmt::format("Query EmbeddingDataType: {} is not support.",
EmbeddingType::EmbeddingDataType2String(knn_scan_shared_data_->query_elem_type_)));
Expand All @@ -183,14 +191,16 @@ void KnnScanFunctionData::Init() {
}
case KnnDistanceType::kL2:
case KnnDistanceType::kHamming: {
auto merge_knn_max = MakeUnique<MergeKnn<QueryDataType, CompareMax, DistDataType>>(knn_scan_shared_data_->query_count_, knn_scan_shared_data_->topk_);
auto merge_knn_max =
MakeUnique<MergeKnn<QueryDataType, CompareMax, DistDataType>>(knn_scan_shared_data_->query_count_, knn_scan_shared_data_->topk_);
merge_knn_max->Begin();
merge_knn_base_ = std::move(merge_knn_max);
break;
}
case KnnDistanceType::kCosine:
case KnnDistanceType::kInnerProduct: {
auto merge_knn_min = MakeUnique<MergeKnn<QueryDataType, CompareMin, DistDataType>>(knn_scan_shared_data_->query_count_, knn_scan_shared_data_->topk_);
auto merge_knn_min =
MakeUnique<MergeKnn<QueryDataType, CompareMin, DistDataType>>(knn_scan_shared_data_->query_count_, knn_scan_shared_data_->topk_);
merge_knn_min->Begin();
merge_knn_base_ = std::move(merge_knn_min);
break;
Expand Down
4 changes: 2 additions & 2 deletions src/parser/expr/knn_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ KnnExpr::~KnnExpr() {
column_expr_ = nullptr;
}

if(opt_params_ != nullptr) {
for(auto* param_ptr: *opt_params_) {
if (opt_params_ != nullptr) {
for (auto *param_ptr : *opt_params_) {
delete param_ptr;
param_ptr = nullptr;
}
Expand Down
14 changes: 13 additions & 1 deletion test/sql/dml/insert/test_insert_embedding.slt
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,16 @@ SELECT * FROM sqllogic_test_insert_embedding;
0 [0.1,1.1,2.1,3.1,4.1,5.1,6.1,7.1,8.1,9.1,10.1,11.1,12.1,13.1,14.1,15.1]

statement ok
DROP TABLE sqllogic_test_insert_embedding;
DROP TABLE sqllogic_test_insert_embedding;

statement ok
CREATE TABLE sqllogic_test_insert_embedding (col1 INT, col2 EMBEDDING(BIT, 8));

query I
INSERT INTO sqllogic_test_insert_embedding VALUES(0, [0.0, 0.1, 0.2, 0.3, 0.3, 0.2, 0.1, 0.0]);
----

query II
SELECT * FROM sqllogic_test_insert_embedding;
----
0 [01111110]
Loading

0 comments on commit 9a17553

Please sign in to comment.