Skip to content

Commit

Permalink
Fix AVX2 distance function and build with AVX2 if available on host m…
Browse files Browse the repository at this point in the history
…achine (#524)
  • Loading branch information
jparismorgan authored Oct 17, 2024
1 parent 555b2e9 commit fd1bafb
Show file tree
Hide file tree
Showing 8 changed files with 413 additions and 149 deletions.
8 changes: 8 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,14 @@ if (NOT $ENV{CIBUILDWHEEL} EQUAL 1 AND NOT $ENV{CONDA_BUILD} EQUAL 1)
endif()
endif()

# AVX2 flag
#
# CheckAVX2Support() (from the CheckAVX2Support module) sets
# COMPILER_SUPPORTS_AVX2 and COMPILER_AVX2_FLAG. When the check succeeds, the
# AVX2 flag is added to the compile options for this directory and below, and
# AVX2_ENABLED is defined for the sources.
include(CheckAVX2Support)
CheckAVX2Support()
if (COMPILER_SUPPORTS_AVX2)
  add_compile_options(${COMPILER_AVX2_FLAG})
  if (NOT MSVC)
    # -mfma is the GCC/Clang spelling for enabling FMA code generation. It is
    # only passed to non-MSVC toolchains; for MSVC-style drivers the module
    # selects /arch:AVX2 and no separate FMA switch is added here.
    add_compile_options(-mfma)
  endif()
  add_definitions(-DAVX2_ENABLED)
endif()

# Default to Release build
if (NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
message(STATUS "No build type selected, default to Release")
Expand Down
72 changes: 72 additions & 0 deletions src/cmake/Modules/CheckAVX2Support.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#
# CheckAVX2Support.cmake
#
#
# The MIT License
#
# Copyright (c) 2018-2021 TileDB, Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# This file defines a function to detect toolchain support for AVX2.
#

include(CheckCXXSourceRuns)
include(CMakePushCheckState)

#
# Determines if AVX2 is available.
#
# This function sets two variables in the cache:
# COMPILER_SUPPORTS_AVX2 - Set to true if a test program built with the AVX2 flag compiles and runs successfully.
# COMPILER_AVX2_FLAG - Set to the appropriate flag to enable AVX2.
#
function (CheckAVX2Support)
  # Explicit opt-out: if COMPILER_SUPPORTS_AVX2 is already defined to a false
  # value other than "" (e.g. OFF, 0, FALSE), return without checking for AVX2
  # support. An empty string is not treated as an opt-out.
  if (DEFINED COMPILER_SUPPORTS_AVX2 AND
      NOT COMPILER_SUPPORTS_AVX2 STREQUAL "" AND
      NOT COMPILER_SUPPORTS_AVX2)
    message("AVX2 compiler support disabled by COMPILER_SUPPORTS_AVX2=${COMPILER_SUPPORTS_AVX2}")
    return()
  endif()

  # Pick the flag spelling for the active compiler family. Because this is a
  # non-FORCE cache entry, a value already present in the cache is kept.
  if (MSVC)
    set(COMPILER_AVX2_FLAG "/arch:AVX2" CACHE STRING "Compiler flag for AVX2 support.")
  else()
    set(COMPILER_AVX2_FLAG "-mavx2" CACHE STRING "Compiler flag for AVX2 support.")
  endif()

  # Compile AND run a small AVX2 program with the flag appended to the check
  # flags. push/pop keeps the CMAKE_REQUIRED_FLAGS change local to this check.
  #
  # The probe takes its input from a volatile variable and derives its exit
  # code from the computed lanes. With an unused, constant-only computation the
  # optimizer is free to delete the AVX2 instruction when the check is built
  # with optimization enabled (for example via user-supplied CMAKE_CXX_FLAGS),
  # in which case the program would "run" fine on a CPU without AVX2.
  cmake_push_check_state()
  set(CMAKE_REQUIRED_FLAGS "${CMAKE_REQUIRED_FLAGS} ${COMPILER_AVX2_FLAG}")
  check_cxx_source_runs("
    #include <immintrin.h>
    int main() {
      volatile int seed = -8;
      __m256i packed = _mm256_set1_epi32(seed);
      __m256i absolute_values = _mm256_abs_epi32(packed);
      int lanes[8];
      _mm256_storeu_si256((__m256i*)lanes, absolute_values);
      return (lanes[0] == 8 && lanes[7] == 8) ? 0 : 1;
    }"
    COMPILER_SUPPORTS_AVX2
  )
  cmake_pop_check_state()

  if (COMPILER_SUPPORTS_AVX2)
    message(STATUS "AVX2 support detected.")
  else()
    message(STATUS "AVX2 support not detected.")
  endif()
endfunction()
90 changes: 64 additions & 26 deletions src/include/detail/scoring/inner_product_avx.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,29 +88,38 @@ inline float avx2_inner_product(const V& a, const W& b) {

template <feature_vector V, feature_vector W>
requires std::same_as<typename V::value_type, float> &&
std::same_as<typename W::value_type, uint8_t>
(std::same_as<typename W::value_type, uint8_t> ||
std::same_as<typename W::value_type, int8_t>)
inline float avx2_inner_product(const V& a, const W& b) {
// @todo Align on 256 bit boundaries
const size_t start = 0;
const size_t size_a = size(a);
const size_t stop = size_a - (size_a % 8);

const float* a_ptr = a.data();
const uint8_t* b_ptr = b.data();
// Can be uint8_t* or int8_t*
const auto* b_ptr = b.data();

__m256 vec_sum = _mm256_setzero_ps();

for (size_t i = start; i < stop; i += 8) {
// Load 8 floats
__m256 a_floats = _mm256_loadu_ps(a_ptr + i + 0);
__m256 a_floats = _mm256_loadu_ps(a_ptr + i);

// Load 8 bytes
// Load 8 bytes (uint8_t or int8_t)
__m128i vec_b = _mm_loadu_si64((__m64*)(b_ptr + i));

// Zero extend 8bit to 32bit ints
__m256i b_ints = _mm256_cvtepu8_epi32(vec_b);

// Convert signed integers to floats
// Conditionally convert based on the type of W::value_type
__m256i b_ints;
if constexpr (std::same_as<typename W::value_type, uint8_t>) {
// Zero extend uint8_t to int32_t
b_ints = _mm256_cvtepu8_epi32(vec_b);
} else if constexpr (std::same_as<typename W::value_type, int8_t>) {
// Sign extend int8_t to int32_t
b_ints = _mm256_cvtepi8_epi32(vec_b);
}

// Convert the 32-bit integers to floats
__m256 b_floats = _mm256_cvtepi32_ps(b_ints);

// Multiply and accumulate
Expand Down Expand Up @@ -139,30 +148,39 @@ inline float avx2_inner_product(const V& a, const W& b) {
}

template <feature_vector V, feature_vector W>
requires std::same_as<typename V::value_type, uint8_t> &&
std::same_as<typename W::value_type, float>
requires(std::same_as<typename V::value_type, uint8_t> ||
std::same_as<typename V::value_type, int8_t>) &&
std::same_as<typename W::value_type, float>
inline float avx2_inner_product(const V& a, const W& b) {
// @todo Align on 256 bit boundaries
const size_t start = 0;
const size_t size_a = size(a);
const size_t stop = size_a - (size_a % 8);

const uint8_t* a_ptr = a.data();
// Can be uint8_t* or int8_t*
const auto* a_ptr = a.data();
const float* b_ptr = b.data();

__m256 vec_sum = _mm256_setzero_ps();

for (size_t i = start; i < stop; i += 8) {
// Load 8 bytes == 64 bits -- zeros out top 8 bytes
// Load 8 bytes (either uint8_t or int8_t)
__m128i vec_a = _mm_loadu_si64((__m64*)(a_ptr + i));

// Load 8 floats
__m256 b_floats = _mm256_loadu_ps(b_ptr + i + 0);

// Zero extend 8bit to 32bit ints
__m256i a_ints = _mm256_cvtepu8_epi32(vec_a);

// Convert signed integers to floats
// Extend 8 bit to 32 bit ints
__m256i a_ints;
if constexpr (std::same_as<typename V::value_type, uint8_t>) {
// Zero extend uint8_t to int32_t
a_ints = _mm256_cvtepu8_epi32(vec_a);
} else if constexpr (std::same_as<typename V::value_type, int8_t>) {
// Sign extend int8_t to int32_t
a_ints = _mm256_cvtepi8_epi32(vec_a);
}

// Convert the 32-bit integers to floats
__m256 a_floats = _mm256_cvtepi32_ps(a_ints);

// Multiply and accumulate
Expand Down Expand Up @@ -191,29 +209,49 @@ inline float avx2_inner_product(const V& a, const W& b) {
}

template <feature_vector V, feature_vector W>
requires std::same_as<typename V::value_type, uint8_t> &&
std::same_as<typename W::value_type, uint8_t>
requires(std::same_as<typename V::value_type, uint8_t> ||
std::same_as<typename V::value_type, int8_t>) &&
(std::same_as<typename W::value_type, uint8_t> ||
std::same_as<typename W::value_type, int8_t>)
inline float avx2_inner_product(const V& a, const W& b) {
// @todo Align on 256 bit boundaries
const size_t start = 0;
const size_t size_a = size(a);
const size_t stop = size_a - (size_a % 8);

const uint8_t* a_ptr = a.data();
const uint8_t* b_ptr = b.data();
// Can be either uint8_t* or int8_t*
const auto* a_ptr = a.data();
// Can be either uint8_t* or int8_t*
const auto* b_ptr = b.data();

__m256 vec_sum = _mm256_setzero_ps();

for (size_t i = start; i < stop; i += 8) {
// Load 8 bytes == 64 bits -- zeros out top 8 bytes
// Load 8 bytes (uint8_t or int8_t) from both vectors
__m128i vec_a = _mm_loadu_si64((__m64*)(a_ptr + i));
__m128i vec_b = _mm_loadu_si64((__m64*)(b_ptr + i));

// Zero extend 8bit to 32bit ints
__m256i a_ints = _mm256_cvtepu8_epi32(vec_a);
__m256i b_ints = _mm256_cvtepu8_epi32(vec_b);

// Convert signed integers to floats
// Extend 8 bit to 32 bit ints
__m256i a_ints;
if constexpr (std::same_as<typename V::value_type, uint8_t>) {
// Zero extend uint8_t to int32_t
a_ints = _mm256_cvtepu8_epi32(vec_a);
} else if constexpr (std::same_as<typename V::value_type, int8_t>) {
// Sign extend int8_t to int32_t
a_ints = _mm256_cvtepi8_epi32(vec_a);
}

// Conditionally convert based on the type of W::value_type
__m256i b_ints;
if constexpr (std::same_as<typename W::value_type, uint8_t>) {
// Zero extend uint8_t to int32_t
b_ints = _mm256_cvtepu8_epi32(vec_b);
} else if constexpr (std::same_as<typename W::value_type, int8_t>) {
// Sign extend int8_t to int32_t
b_ints = _mm256_cvtepi8_epi32(vec_b);
}

// Convert the 32-bit integers to floats for both vectors
__m256 a_floats = _mm256_cvtepi32_ps(a_ints);
__m256 b_floats = _mm256_cvtepi32_ps(b_ints);

Expand Down
21 changes: 13 additions & 8 deletions src/include/detail/scoring/l2_distance.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,12 @@ inline float naive_sum_of_squares(const V& a, const W& b) {
}

/**
* Compute l2 distance between vector of float and vector of uint8_t
* Compute l2 distance between vector of float and vector of uint8_t or int8_t
*/
template <feature_vector V, feature_vector W>
requires std::same_as<typename V::value_type, float> &&
std::same_as<typename W::value_type, uint8_t>
(std::same_as<typename W::value_type, uint8_t> ||
std::same_as<typename W::value_type, int8_t>)
inline float naive_sum_of_squares(const V& a, const W& b) {
size_t size_a = size(a);
float sum = 0.0;
Expand All @@ -94,11 +95,12 @@ inline float naive_sum_of_squares(const V& a, const W& b) {
}

/**
* Compute l2 distance between vector of uint8_t and vector of float
* Compute l2 distance between vector of uint8_t or int8_t and vector of float
*/
template <feature_vector V, feature_vector W>
requires std::same_as<typename V::value_type, uint8_t> &&
std::same_as<typename W::value_type, float>
requires(std::same_as<typename V::value_type, uint8_t> ||
std::same_as<typename V::value_type, int8_t>) &&
std::same_as<typename W::value_type, float>
inline float naive_sum_of_squares(const V& a, const W& b) {
size_t size_a = size(a);
float sum = 0.0;
Expand All @@ -110,11 +112,14 @@ inline float naive_sum_of_squares(const V& a, const W& b) {
}

/**
* Compute l2 distance between vector of uint8_t and vector of uint8_t
* Compute l2 distance between vector of uint8_t or int8_t and vector of uint8_t
* or int8_t
*/
template <feature_vector V, feature_vector W>
requires std::same_as<typename V::value_type, uint8_t> &&
std::same_as<typename W::value_type, uint8_t>
requires(std::same_as<typename V::value_type, uint8_t> ||
std::same_as<typename V::value_type, int8_t>) &&
(std::same_as<typename W::value_type, uint8_t> ||
std::same_as<typename W::value_type, int8_t>)
inline float naive_sum_of_squares(const V& a, const W& b) {
size_t size_a = size(a);
float sum = 0.0;
Expand Down
Loading

0 comments on commit fd1bafb

Please sign in to comment.