Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Supports Binary vector with Hamming distance. #2070

Merged
merged 7 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion src/common/simd/distance_simd_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

module;

#include <cmath>
#include "simd_common_intrin_include.h"
#include <cmath>

/*
#if defined(__x86_64__) && (defined(__clang_major__) && (__clang_major__ > 10))
Expand Down Expand Up @@ -76,6 +76,57 @@ f32 CosineDistance_common(const f32 *x, const f32 *y, SizeT d) {
return dot ? dot / sqrt(sqr_x * sqr_y) : 0.0f;
}

// Scalar Hamming distance between two bit vectors packed into bytes.
// d is the length in BYTES; the result is the total count of differing bits.
f32 HammingDistance_common(const u8 *x, const u8 *y, SizeT d) {
    f32 total = 0.0f;
    SizeT byte_idx = 0;
    while (byte_idx < d) {
        const u8 diff_byte = static_cast<u8>(x[byte_idx] ^ y[byte_idx]);
        total += static_cast<f32>(__builtin_popcount(diff_byte));
        ++byte_idx;
    }
    return total;
}

#if defined(__AVX2__)

// AVX2 Hamming distance over byte-packed bit vectors.
// d is the length in BYTES. Processes 32 bytes (256 bits) per iteration and
// delegates the remaining tail to the scalar implementation.
f32 HammingDistance_avx2(const u8 *x, const u8 *y, SizeT d) {
    f32 result = 0;
    SizeT pos = 0;
    // 8 * 32 = 256 bits per vector.
    // Use <= so a length that is an exact multiple of 32 is fully handled by
    // the SIMD path instead of dropping the whole final vector to the scalar tail.
    for (; pos + 32 <= d; pos += 32) {
        __m256i xor_result =
            _mm256_xor_si256(_mm256_loadu_si256(reinterpret_cast<const __m256i *>(x)), _mm256_loadu_si256(reinterpret_cast<const __m256i *>(y)));
        result += popcount_avx2(xor_result);
        x += 32;
        y += 32;
    }
    if (pos < d) {
        // x and y have already been advanced past the SIMD-processed prefix.
        result += HammingDistance_common(x, y, d - pos);
    }
    return result;
}

#endif // defined (__AVX2__)

#if defined(__SSE2__)

// SSE2 Hamming distance over byte-packed bit vectors.
// d is the length in BYTES. Processes 16 bytes (128 bits) per iteration and
// delegates the remaining tail to the scalar implementation.
f32 HammingDistance_sse2(const u8 *x, const u8 *y, SizeT d) {
    f32 result = 0;
    SizeT pos = 0;
    // 8 * 16 = 128 bits per vector.
    // Use <= so a length that is an exact multiple of 16 is fully handled by
    // the SIMD path instead of dropping the whole final vector to the scalar tail.
    for (; pos + 16 <= d; pos += 16) {
        __m128i xor_result =
            _mm_xor_si128(_mm_loadu_si128(reinterpret_cast<const __m128i *>(x)), _mm_loadu_si128(reinterpret_cast<const __m128i *>(y)));
        result += popcount_sse2(xor_result);
        x += 16;
        y += 16;
    }
    if (pos < d) {
        // x and y have already been advanced past the SIMD-processed prefix.
        result += HammingDistance_common(x, y, d - pos);
    }
    return result;
}

#endif // defined (__SSE2__)

#if defined(__AVX2__)
inline f32 L2Distance_avx2_128(const f32 *vector1, const f32 *vector2, SizeT) {
__m256 diff_1 = _mm256_sub_ps(_mm256_loadu_ps(vector1), _mm256_loadu_ps(vector2));
Expand Down
8 changes: 8 additions & 0 deletions src/common/simd/distance_simd_functions.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,20 @@ export f32 IPDistance_common(const f32 *x, const f32 *y, SizeT d);

export f32 CosineDistance_common(const f32 *x, const f32 *y, SizeT d);

export f32 HammingDistance_common(const u8 *x, const u8 *y, SizeT d);

#if defined(__AVX2__)
export f32 L2Distance_avx2(const f32 *vector1, const f32 *vector2, SizeT dimension);

export f32 IPDistance_avx2(const f32 *vector1, const f32 *vector2, SizeT dimension);

export f32 CosineDistance_avx2(const f32 *vector1, const f32 *vector2, SizeT dimension);

export f32 HammingDistance_avx2(const u8 *vector1, const u8 *vector2, SizeT dimension);
#endif

#if defined(__SSE2__)
export f32 HammingDistance_sse2(const u8 *vector1, const u8 *vector2, SizeT dimension); // dimension is in bytes
#endif

} // namespace infinity
79 changes: 58 additions & 21 deletions src/common/simd/simd_common_tools.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ export U8MaskPtr GetU8MasksForAVX2();

#ifdef __SSE__
// https://stackoverflow.com/questions/6996764/fastest-way-to-do-horizontal-sse-vector-sum-or-other-reduction/35270026#35270026
export inline float hsum_ps_sse1(__m128 v) { // v = [ D C | B A ]
__m128 shuf = _mm_shuffle_ps(v, v, _MM_SHUFFLE(2, 3, 0, 1)); // [ C D | A B ]
__m128 sums = _mm_add_ps(v, shuf); // sums = [ D+C C+D | B+A A+B ]
shuf = _mm_movehl_ps(shuf, sums); // [ C D | D+C C+D ] // let the compiler avoid a mov by reusing shuf
sums = _mm_add_ss(sums, shuf);
return _mm_cvtss_f32(sums);
export inline float hsum_ps_sse1(__m128 v) { // v = [ D C | B A ]
__m128 shuf = _mm_shuffle_ps(v, v, _MM_SHUFFLE(2, 3, 0, 1)); // [ C D | A B ]
__m128 sums = _mm_add_ps(v, shuf); // sums = [ D+C C+D | B+A A+B ]
shuf = _mm_movehl_ps(shuf, sums); // [ C D | D+C C+D ] // let the compiler avoid a mov by reusing shuf
sums = _mm_add_ss(sums, shuf);
return _mm_cvtss_f32(sums);
}
#endif

Expand Down Expand Up @@ -60,41 +60,37 @@ export inline float hsum256_ps_avx(__m256 v) {
#ifdef __SSE2__
// https://stackoverflow.com/questions/6996764/fastest-way-to-do-horizontal-sse-vector-sum-or-other-reduction/35270026#35270026
export int hsum_epi32_sse2(__m128i x) {
__m128i hi64 = _mm_shuffle_epi32(x, _MM_SHUFFLE(1, 0, 3, 2));
__m128i hi64 = _mm_shuffle_epi32(x, _MM_SHUFFLE(1, 0, 3, 2));
__m128i sum64 = _mm_add_epi32(hi64, x);
__m128i hi32 = _mm_shufflelo_epi16(sum64, _MM_SHUFFLE(1, 0, 3, 2)); // Swap the low two elements
__m128i hi32 = _mm_shufflelo_epi16(sum64, _MM_SHUFFLE(1, 0, 3, 2)); // Swap the low two elements
__m128i sum32 = _mm_add_epi32(sum64, hi32);
return _mm_cvtsi128_si32(sum32); // SSE2 movd
return _mm_cvtsi128_si32(sum32); // SSE2 movd
}
#endif

#ifdef __AVX__
// https://stackoverflow.com/questions/60108658/fastest-method-to-calculate-sum-of-all-packed-32-bit-integers-using-avx512-or-av/60109639#60109639
export int hsum_epi32_avx(__m128i x)
{
__m128i hi64 = _mm_unpackhi_epi64(x, x); // 3-operand non-destructive AVX lets us save a byte without needing a movdqa
export int hsum_epi32_avx(__m128i x) {
__m128i hi64 = _mm_unpackhi_epi64(x, x); // 3-operand non-destructive AVX lets us save a byte without needing a movdqa
__m128i sum64 = _mm_add_epi32(hi64, x);
__m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); // Swap the low two elements
__m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); // Swap the low two elements
__m128i sum32 = _mm_add_epi32(sum64, hi32);
return _mm_cvtsi128_si32(sum32); // movd
return _mm_cvtsi128_si32(sum32); // movd
}
#endif

#ifdef __AVX2__
// https://stackoverflow.com/questions/60108658/fastest-method-to-calculate-sum-of-all-packed-32-bit-integers-using-avx512-or-av/60109639#60109639
// only needs AVX2
export int hsum_8x32_avx2(__m256i v)
{
__m128i sum128 = _mm_add_epi32(
_mm256_castsi256_si128(v),
_mm256_extracti128_si256(v, 1)); // silly GCC uses a longer AXV512VL instruction if AVX512 is enabled :/
export int hsum_8x32_avx2(__m256i v) {
__m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(v),
_mm256_extracti128_si256(v, 1)); // silly GCC uses a longer AXV512VL instruction if AVX512 is enabled :/
return hsum_epi32_avx(sum128);
}
#endif

#ifdef __AVX512F__
export int hsum_epi32_avx512(__m512i v)
{
export int hsum_epi32_avx512(__m512i v) {
__m256i lo = _mm512_castsi512_si256(v);
__m256i hi = _mm512_extracti64x4_epi64(v, 1);
__m256i sum = _mm256_add_epi32(lo, hi);
Expand Down Expand Up @@ -154,4 +150,45 @@ export inline __m512i abs_sub_epu8_avx512(const __m512i a, const __m512i b) {
}
#endif

// https://github.com/WojciechMula/sse-popcount/blob/master/popcnt-avx2-harley-seal.cpp
#if defined(__AVX2__)
// Bit-parallel (SWAR) popcount of all 256 bits of v; based on
// https://github.com/WojciechMula/sse-popcount/blob/master/popcnt-avx2-harley-seal.cpp
export inline int popcount_avx2(const __m256i v) {
    const __m256i m1 = _mm256_set1_epi8(0x55);
    const __m256i m2 = _mm256_set1_epi8(0x33);
    const __m256i m4 = _mm256_set1_epi8(0x0F);

    // Classic per-byte popcount; use explicit _mm256_and_si256 rather than the
    // GNU `&` vector-operator extension so the code also builds under MSVC.
    const __m256i t1 = _mm256_sub_epi8(v, _mm256_and_si256(_mm256_srli_epi16(v, 1), m1));
    const __m256i t2 = _mm256_add_epi8(_mm256_and_si256(t1, m2), _mm256_and_si256(_mm256_srli_epi16(t1, 2), m2));
    const __m256i t3 = _mm256_and_si256(_mm256_add_epi8(t2, _mm256_srli_epi16(t2, 4)), m4);
    // Horizontal sum of the per-byte counts into four 64-bit lanes.
    __m256i sad = _mm256_sad_epu8(t3, _mm256_setzero_si256());

    __m128i sad_low = _mm256_castsi256_si128(sad);
    __m128i sad_high = _mm256_extracti128_si256(sad, 1);
    __m128i sum = _mm_add_epi64(sad_low, sad_high);

    // Each 64-bit lane holds at most 2 * 8 * 255, so 32-bit extraction is exact
    // and avoids _mm_extract_epi64, which is unavailable in 32-bit builds.
    return _mm_cvtsi128_si32(sum) + _mm_cvtsi128_si32(_mm_srli_si128(sum, 8));
}
#endif // defined (__AVX2__)

// https://github.com/WojciechMula/sse-popcount/blob/master/popcnt-sse-harley-seal.cpp
#if defined(__SSE2__)
// Bit-parallel (SWAR) popcount of all 128 bits of x; based on
// https://github.com/WojciechMula/sse-popcount/blob/master/popcnt-sse-harley-seal.cpp
export inline int popcount_sse2(const __m128i x) {
    const __m128i m1 = _mm_set1_epi8(0x55);
    const __m128i m2 = _mm_set1_epi8(0x33);
    const __m128i m4 = _mm_set1_epi8(0x0F);

    // Classic per-byte popcount; use explicit _mm_and_si128 rather than the
    // GNU `&` vector-operator extension so the code also builds under MSVC.
    const __m128i t2 = _mm_sub_epi8(x, _mm_and_si128(_mm_srli_epi16(x, 1), m1));
    const __m128i t3 = _mm_add_epi8(_mm_and_si128(t2, m2), _mm_and_si128(_mm_srli_epi16(t2, 2), m2));
    const __m128i t4 = _mm_and_si128(_mm_add_epi8(t3, _mm_srli_epi16(t3, 4)), m4);

    // Horizontal sum of the per-byte counts into two 64-bit lanes.
    __m128i sad = _mm_sad_epu8(t4, _mm_setzero_si128());

    // _mm_extract_epi64 requires SSE4.1 (and 64-bit mode), but this function is
    // guarded only by __SSE2__. Each 64-bit lane holds at most 8 * 255, so
    // SSE2-only 32-bit extraction is exact.
    return _mm_cvtsi128_si32(sad) + _mm_cvtsi128_si32(_mm_srli_si128(sad, 8));
}
#endif // defined (__SSE2__)

} // namespace infinity
1 change: 1 addition & 0 deletions src/common/simd/simd_functions.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ export struct SIMD_FUNCTIONS {
F32DistanceFuncType L2Distance_func_ptr_ = GetL2DistanceFuncPtr();
F32DistanceFuncType IPDistance_func_ptr_ = GetIPDistanceFuncPtr();
F32DistanceFuncType CosineDistance_func_ptr_ = GetCosineDistanceFuncPtr();
U8HammingDistanceFuncType HammingDistance_func_ptr_ = GetHammingDistanceFuncPtr();

// HNSW F32
F32DistanceFuncType HNSW_F32L2_ptr_ = Get_HNSW_F32L2_ptr();
Expand Down
10 changes: 10 additions & 0 deletions src/common/simd/simd_init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,16 @@ F32DistanceFuncType GetCosineDistanceFuncPtr() {
return &CosineDistance_common;
}

// Select the Hamming distance implementation for this build.
// The dimension passed to the returned function is measured in bytes.
// NOTE(review): selection here is purely compile-time via the build's ISA
// macros (first matching #ifdef wins: AVX2, then SSE2, then scalar), unlike
// the AVX512 paths below that also consult IsAVX512Supported() at runtime.
U8HammingDistanceFuncType GetHammingDistanceFuncPtr() {
#ifdef __AVX2__
    return &HammingDistance_avx2;
#endif
#ifdef __SSE2__
    return &HammingDistance_sse2;
#endif
    return &HammingDistance_common;
}

F32DistanceFuncType Get_HNSW_F32L2_16_ptr() {
#if defined(__AVX512F__)
if (IsAVX512Supported()) {
Expand Down
6 changes: 6 additions & 0 deletions src/common/simd/simd_init.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ export using F32DistanceFuncType = f32(*)(const f32 *, const f32 *, SizeT);
export using I8DistanceFuncType = i32(*)(const i8 *, const i8 *, SizeT);
export using I8CosDistanceFuncType = f32(*)(const i8 *, const i8 *, SizeT);
export using U8DistanceFuncType = i32(*)(const u8 *, const u8 *, SizeT);
// NOTE: the dimension argument for Hamming distance is measured in bytes, not bits
export using U8HammingDistanceFuncType = f32(*)(const u8 *, const u8 *, SizeT);
export using U8CosDistanceFuncType = f32(*)(const u8 *, const u8 *, SizeT);
export using MaxSimF32BitIPFuncType = f32(*)(const f32 *, const u8 *, SizeT);
export using MaxSimI32BitIPFuncType = i32(*)(const i32 *, const u8 *, SizeT);
Expand All @@ -42,6 +44,10 @@ export using SearchTop1WithDisF32U32FuncType = void(*)(u32, u32, const f32 *, u3
export F32DistanceFuncType GetL2DistanceFuncPtr();
export F32DistanceFuncType GetIPDistanceFuncPtr();
export F32DistanceFuncType GetCosineDistanceFuncPtr();

// u8 Hamming distance functions
export U8HammingDistanceFuncType GetHammingDistanceFuncPtr();

// HNSW F32
export F32DistanceFuncType Get_HNSW_F32L2_ptr();
export F32DistanceFuncType Get_HNSW_F32L2_16_ptr();
Expand Down
28 changes: 21 additions & 7 deletions src/executor/operator/physical_scan/physical_knn_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ auto GetKnnExprForCalculation(const KnnExpression &src_knn_expr, const Embedding
const auto src_query_embedding_type = src_knn_expr.embedding_data_type_;
EmbeddingDataType new_query_embedding_type = EmbeddingDataType::kElemInvalid;
switch (column_embedding_type) {
case EmbeddingDataType::kElemBit:
case EmbeddingDataType::kElemUInt8:
case EmbeddingDataType::kElemInt8: {
// expect query embedding to be the same type
Expand All @@ -86,7 +87,6 @@ auto GetKnnExprForCalculation(const KnnExpression &src_knn_expr, const Embedding
// no need for alignment
break;
}
case EmbeddingDataType::kElemBit:
case EmbeddingDataType::kElemInt16:
case EmbeddingDataType::kElemInt32:
case EmbeddingDataType::kElemInt64: {
Expand Down Expand Up @@ -228,7 +228,9 @@ void PhysicalKnnScan::ExecuteInternalByColumnLogicalType(QueryContext *query_con
case EmbeddingDataType::kElemBFloat16: {
return ExecuteInternalByColumnDataType<t, BFloat16T>(query_context, knn_scan_operator_state);
}
case EmbeddingDataType::kElemBit:
case EmbeddingDataType::kElemBit: {
return ExecuteInternalByColumnDataType<t, u8>(query_context, knn_scan_operator_state);
}
case EmbeddingDataType::kElemInt16:
case EmbeddingDataType::kElemInt32:
case EmbeddingDataType::kElemInt64:
Expand Down Expand Up @@ -322,7 +324,16 @@ void PhysicalKnnScan::ExecuteInternalByColumnDataType(QueryContext *query_contex
UnrecoverableError(fmt::format("BUG: Query embedding data type: {} should be cast to Float before knn search!",
EmbeddingT::EmbeddingDataType2String(query_elem_type)));
}
case EmbeddingDataType::kElemBit:
case EmbeddingDataType::kElemBit: {
switch (query_dist_type) {
case KnnDistanceType::kHamming: {
return ExecuteDispatchHelper<t, ColumnDataT, u8, CompareMax, f32>::Execute(this, query_context, knn_scan_operator_state);
}
default: {
return knn_distance_error();
}
}
}
case EmbeddingDataType::kElemInt16:
case EmbeddingDataType::kElemInt32:
case EmbeddingDataType::kElemInt64:
Expand Down Expand Up @@ -414,8 +425,7 @@ void PhysicalKnnScan::PlanWithIndex(QueryContext *query_context) { // TODO: retu
RecoverableError(std::move(error_status));
}
// check index type
if (auto index_type = table_index_entry->index_base()->index_type_;
index_type != IndexType::kIVF and index_type != IndexType::kHnsw) {
if (auto index_type = table_index_entry->index_base()->index_type_; index_type != IndexType::kIVF and index_type != IndexType::kHnsw) {
LOG_ERROR("Invalid index type");
Status error_status = Status::InvalidIndexType("invalid index");
RecoverableError(std::move(error_status));
Expand Down Expand Up @@ -809,7 +819,12 @@ struct BruteForceBlockScan<LogicalType::kEmbedding, ColumnDataType, QueryDataTyp
}
target_ptr = buffer_ptr_for_cast.get();
}
merge_heap->Search(knn_query_ptr, target_ptr, embedding_dim, dist_func->dist_func_, row_count, segment_id, block_id, bitmask);
auto embedding_info = static_cast<EmbeddingInfo *>(column_vector.data_type()->type_info().get());
if (embedding_info->Type() == EmbeddingDataType::kElemBit) {
merge_heap->Search(knn_query_ptr, target_ptr, embedding_dim / 8, dist_func->dist_func_, row_count, segment_id, block_id, bitmask);
} else {
merge_heap->Search(knn_query_ptr, target_ptr, embedding_dim, dist_func->dist_func_, row_count, segment_id, block_id, bitmask);
}
}
};

Expand Down Expand Up @@ -880,5 +895,4 @@ void MultiVectorSearchOneLine(MergeKnn<QueryDataType, C, DistanceDataType> *merg
merge_heap->Search(0, &result_dist, &db_row_id, 1);
}


} // namespace infinity
14 changes: 12 additions & 2 deletions src/function/table/knn_scan_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ KnnDistance1<u8, f32>::KnnDistance1(KnnDistanceType dist_type) {
dist_func_ = &hnsw_u8ip_f32_wrapper;
break;
}
case KnnDistanceType::kHamming: {
dist_func_ = GetSIMD_FUNCTIONS().HammingDistance_func_ptr_;
break;
}
default: {
Status status = Status::NotSupport(fmt::format("KnnDistanceType: {} is not support.", (i32)dist_type));
RecoverableError(status);
Expand Down Expand Up @@ -161,6 +165,10 @@ KnnScanFunctionData::KnnScanFunctionData(KnnScanSharedData *shared_data, u32 cur
Init<i8, f32>();
break;
}
case EmbeddingDataType::kElemBit: {
Init<u8, f32>();
break;
}
default: {
Status status = Status::NotSupport(fmt::format("Query EmbeddingDataType: {} is not support.",
EmbeddingType::EmbeddingDataType2String(knn_scan_shared_data_->query_elem_type_)));
Expand All @@ -178,14 +186,16 @@ void KnnScanFunctionData::Init() {
}
case KnnDistanceType::kL2:
case KnnDistanceType::kHamming: {
auto merge_knn_max = MakeUnique<MergeKnn<QueryDataType, CompareMax, DistDataType>>(knn_scan_shared_data_->query_count_, knn_scan_shared_data_->topk_);
auto merge_knn_max =
MakeUnique<MergeKnn<QueryDataType, CompareMax, DistDataType>>(knn_scan_shared_data_->query_count_, knn_scan_shared_data_->topk_);
merge_knn_max->Begin();
merge_knn_base_ = std::move(merge_knn_max);
break;
}
case KnnDistanceType::kCosine:
case KnnDistanceType::kInnerProduct: {
auto merge_knn_min = MakeUnique<MergeKnn<QueryDataType, CompareMin, DistDataType>>(knn_scan_shared_data_->query_count_, knn_scan_shared_data_->topk_);
auto merge_knn_min =
MakeUnique<MergeKnn<QueryDataType, CompareMin, DistDataType>>(knn_scan_shared_data_->query_count_, knn_scan_shared_data_->topk_);
merge_knn_min->Begin();
merge_knn_base_ = std::move(merge_knn_min);
break;
Expand Down
4 changes: 2 additions & 2 deletions src/parser/expr/knn_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ KnnExpr::~KnnExpr() {
column_expr_ = nullptr;
}

if(opt_params_ != nullptr) {
for(auto* param_ptr: *opt_params_) {
if (opt_params_ != nullptr) {
for (auto *param_ptr : *opt_params_) {
delete param_ptr;
param_ptr = nullptr;
}
Expand Down
14 changes: 13 additions & 1 deletion test/sql/dml/insert/test_insert_embedding.slt
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,16 @@ SELECT * FROM sqllogic_test_insert_embedding;
0 [0.1,1.1,2.1,3.1,4.1,5.1,6.1,7.1,8.1,9.1,10.1,11.1,12.1,13.1,14.1,15.1]

statement ok
DROP TABLE sqllogic_test_insert_embedding;
DROP TABLE sqllogic_test_insert_embedding;

statement ok
CREATE TABLE sqllogic_test_insert_embedding (col1 INT, col2 EMBEDDING(BIT, 8));

query I
INSERT INTO sqllogic_test_insert_embedding VALUES(0, [0.0, 0.1, 0.2, 0.3, 0.3, 0.2, 0.1, 0.0]);
----

query II
SELECT * FROM sqllogic_test_insert_embedding;
----
0 [01111110]
Loading
Loading