From fd1bafb4b3f67c42b9ed25c70a5dc3056e0b3b65 Mon Sep 17 00:00:00 2001 From: Paris Morgan Date: Thu, 17 Oct 2024 09:00:47 -0700 Subject: [PATCH] Fix AVX2 distance function and build with AVX2 if available on host machine (#524) --- src/CMakeLists.txt | 8 + src/cmake/Modules/CheckAVX2Support.cmake | 72 +++++++++ .../detail/scoring/inner_product_avx.h | 90 ++++++++--- src/include/detail/scoring/l2_distance.h | 21 ++- src/include/detail/scoring/l2_distance_avx.h | 116 ++++++++------ .../test/unit_inner_product_distance.cc | 150 +++++++++++------- src/include/test/unit_l2_distance.cc | 93 ++++++++++- src/include/test/unit_scoring.cc | 12 +- 8 files changed, 413 insertions(+), 149 deletions(-) create mode 100644 src/cmake/Modules/CheckAVX2Support.cmake diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 5e3e1d1a6..8918400c7 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -135,6 +135,14 @@ if (NOT $ENV{CIBUILDWHEEL} EQUAL 1 AND NOT $ENV{CONDA_BUILD} EQUAL 1) endif() endif() +# AVX2 flag +include(CheckAVX2Support) +CheckAVX2Support() +if (COMPILER_SUPPORTS_AVX2) + add_compile_options(${COMPILER_AVX2_FLAG} -mfma) + add_definitions(-DAVX2_ENABLED) +endif() + # Default to Release build if (NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES) message(STATUS "No build type selected, default to Release") diff --git a/src/cmake/Modules/CheckAVX2Support.cmake b/src/cmake/Modules/CheckAVX2Support.cmake new file mode 100644 index 000000000..2954b2f66 --- /dev/null +++ b/src/cmake/Modules/CheckAVX2Support.cmake @@ -0,0 +1,72 @@ +# +# CheckAVX2Support.cmake +# +# +# The MIT License +# +# Copyright (c) 2018-2021 TileDB, Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# This file defines a function to detect toolchain support for AVX2. +# + +include(CheckCXXSourceRuns) +include(CMakePushCheckState) + +# +# Determines if AVX2 is available. +# +# This function sets two variables in the cache: +# COMPILER_SUPPORTS_AVX2 - Set to true if the compiler supports AVX2. +# COMPILER_AVX2_FLAG - Set to the appropriate flag to enable AVX2. +# +function (CheckAVX2Support) + # If defined to a false value other than "", return without checking for avx2 support + if (DEFINED COMPILER_SUPPORTS_AVX2 AND + NOT COMPILER_SUPPORTS_AVX2 STREQUAL "" AND + NOT COMPILER_SUPPORTS_AVX2) + message("AVX2 compiler support disabled by COMPILER_SUPPORTS_AVX2=${COMPILER_SUPPORTS_AVX2}") + return() + endif() + + if (MSVC) + set(COMPILER_AVX2_FLAG "/arch:AVX2" CACHE STRING "Compiler flag for AVX2 support.") + else() + set(COMPILER_AVX2_FLAG "-mavx2" CACHE STRING "Compiler flag for AVX2 support.") + endif() + + cmake_push_check_state() + set(CMAKE_REQUIRED_FLAGS "${CMAKE_REQUIRED_FLAGS} ${COMPILER_AVX2_FLAG}") + check_cxx_source_runs(" + #include + int main() { + __m256i packed = _mm256_set_epi32(-1, -2, -3, -4, -5, -6, -7, -8); + __m256i absolute_values = _mm256_abs_epi32(packed); + return 0; + }" + COMPILER_SUPPORTS_AVX2 + ) + cmake_pop_check_state() + if (COMPILER_SUPPORTS_AVX2) + message(STATUS "AVX2 support detected.") + else() + message(STATUS "AVX2 support not detected.") + endif() +endfunction() diff --git a/src/include/detail/scoring/inner_product_avx.h b/src/include/detail/scoring/inner_product_avx.h index 189f8a4b3..e5af53581 100644 --- a/src/include/detail/scoring/inner_product_avx.h +++ b/src/include/detail/scoring/inner_product_avx.h @@ -88,7 +88,8 @@ inline float avx2_inner_product(const V& a, const W& b) { template requires std::same_as && - std::same_as + (std::same_as || + std::same_as) inline float avx2_inner_product(const V& a, const W& b) { // @todo Align on 256 bit boundaries const size_t start = 0; @@ -96,21 +97,29 @@ inline float avx2_inner_product(const V& a, const W& b) { const size_t stop = size_a - (size_a % 8); const float* a_ptr = a.data(); - const uint8_t* b_ptr = b.data(); + // Can be uint8_t* or int8_t* + const auto* b_ptr = b.data(); __m256 vec_sum = _mm256_setzero_ps(); for (size_t i = start; i < stop; i += 8) { // Load 8 floats - __m256 a_floats = _mm256_loadu_ps(a_ptr + i + 0); + __m256 a_floats = _mm256_loadu_ps(a_ptr + i); - // Load 8 bytes + // Load 8 bytes (uint8_t or int8_t) __m128i vec_b = _mm_loadu_si64((__m64*)(b_ptr + i)); - // Zero extend 8bit to 32bit ints - __m256i b_ints = _mm256_cvtepu8_epi32(vec_b); - - // Convert signed integers to floats + // Conditionally convert based on the type of W::value_type + __m256i b_ints; + if constexpr (std::same_as) { + // Zero extend uint8_t to int32_t + b_ints = _mm256_cvtepu8_epi32(vec_b); + } else if constexpr (std::same_as) { + // Sign extend int8_t to int32_t + b_ints = _mm256_cvtepi8_epi32(vec_b); + } + + // Convert the 32-bit integers to floats __m256 b_floats = _mm256_cvtepi32_ps(b_ints); // Multiply and accumulate @@ -139,30 +148,39 @@ inline float avx2_inner_product(const V& a, const W& b) { } template - requires std::same_as && - std::same_as + requires(std::same_as || + std::same_as) && + std::same_as inline float avx2_inner_product(const V& a, const W& b) { // @todo Align on 256 bit boundaries const size_t start = 0; const size_t size_a = size(a); const size_t stop = size_a - (size_a % 8); - const uint8_t* a_ptr = a.data(); + // Can be uint8_t* or int8_t* + const auto* a_ptr = a.data(); const float* b_ptr = b.data(); __m256 vec_sum = _mm256_setzero_ps(); for (size_t i = start; i < stop; i += 8) { - // Load 8 bytes == 64 bits -- zeros out top 8 bytes + // Load 8 bytes (either uint8_t or int8_t) __m128i vec_a = _mm_loadu_si64((__m64*)(a_ptr + i)); // Load 8 floats __m256 b_floats = _mm256_loadu_ps(b_ptr + i + 0); - // Zero extend 8bit to 32bit ints - __m256i a_ints = _mm256_cvtepu8_epi32(vec_a); - - // Convert signed integers to floats + // Extend 8 bit to 32 bit ints + __m256i a_ints; + if constexpr (std::same_as) { + // Zero extend uint8_t to int32_t + a_ints = _mm256_cvtepu8_epi32(vec_a); + } else if constexpr (std::same_as) { + // Sign extend int8_t to int32_t + a_ints = _mm256_cvtepi8_epi32(vec_a); + } + + // Convert the 32-bit integers to floats __m256 a_floats = _mm256_cvtepi32_ps(a_ints); // Multiply and accumulate @@ -191,29 +209,49 @@ inline float avx2_inner_product(const V& a, const W& b) { } template - requires std::same_as && - std::same_as + requires(std::same_as || + std::same_as) && + (std::same_as || + std::same_as) inline float avx2_inner_product(const V& a, const W& b) { // @todo Align on 256 bit boundaries const size_t start = 0; const size_t size_a = size(a); const size_t stop = size_a - (size_a % 8); - const uint8_t* a_ptr = a.data(); - const uint8_t* b_ptr = b.data(); + // Can be either uint8_t* or int8_t* + const auto* a_ptr = a.data(); + // Can be either uint8_t* or int8_t* + const auto* b_ptr = b.data(); __m256 vec_sum = _mm256_setzero_ps(); for (size_t i = start; i < stop; i += 8) { - // Load 8 bytes == 64 bits -- zeros out top 8 bytes + // Load 8 bytes (uint8_t or int8_t) from both vectors __m128i vec_a = _mm_loadu_si64((__m64*)(a_ptr + i)); __m128i vec_b = _mm_loadu_si64((__m64*)(b_ptr + i)); - // Zero extend 8bit to 32bit ints - __m256i a_ints = _mm256_cvtepu8_epi32(vec_a); - __m256i b_ints = _mm256_cvtepu8_epi32(vec_b); - - // Convert signed integers to floats + // Extend 8 bit to 32 bit ints + __m256i a_ints; + if constexpr (std::same_as) { + // Zero extend uint8_t to int32_t + a_ints = _mm256_cvtepu8_epi32(vec_a); + } else if constexpr (std::same_as) { + // Sign extend int8_t to int32_t + a_ints = _mm256_cvtepi8_epi32(vec_a); + } + + // Conditionally convert based on the type of W::value_type + __m256i b_ints; + if constexpr (std::same_as) { + // Zero extend uint8_t to int32_t + b_ints = _mm256_cvtepu8_epi32(vec_b); + } else if constexpr (std::same_as) { + // Sign extend int8_t to int32_t + b_ints = _mm256_cvtepi8_epi32(vec_b); + } + + // Convert the 32-bit integers to floats for both vectors __m256 a_floats = _mm256_cvtepi32_ps(a_ints); __m256 b_floats = _mm256_cvtepi32_ps(b_ints); diff --git a/src/include/detail/scoring/l2_distance.h b/src/include/detail/scoring/l2_distance.h index 8f5aa1429..fe9f125d9 100644 --- a/src/include/detail/scoring/l2_distance.h +++ b/src/include/detail/scoring/l2_distance.h @@ -78,11 +78,12 @@ inline float naive_sum_of_squares(const V& a, const W& b) { } /** - * Compute l2 distance between vector of float and vector of uint8_t + * Compute l2 distance between vector of float and vector of uint8_t or int8_t */ template requires std::same_as && - std::same_as + (std::same_as || + std::same_as) inline float naive_sum_of_squares(const V& a, const W& b) { size_t size_a = size(a); float sum = 0.0; @@ -94,11 +95,12 @@ inline float naive_sum_of_squares(const V& a, const W& b) { } /** - * Compute l2 distance between vector of uint8_t and vector of float + * Compute l2 distance between vector of uint8_t or int8_t and vector of float */ template - requires std::same_as && - std::same_as + requires(std::same_as || + std::same_as) && + std::same_as inline float naive_sum_of_squares(const V& a, const W& b) { size_t size_a = size(a); float sum = 0.0; @@ -110,11 +112,14 @@ inline float naive_sum_of_squares(const V& a, const W& b) { } /** - * Compute l2 distance between vector of uint8_t and vector of uint8_t + * Compute l2 distance between vector of uint8_t or int8_t and vector of uint8_t + * or int8_t */ template - requires std::same_as && - std::same_as + requires(std::same_as || + std::same_as) && + (std::same_as || + std::same_as) inline float naive_sum_of_squares(const V& a, const W& b) { size_t size_a = size(a); float sum = 0.0; diff --git a/src/include/detail/scoring/l2_distance_avx.h b/src/include/detail/scoring/l2_distance_avx.h index f4b695ea2..5a7529fc9 100644 --- a/src/include/detail/scoring/l2_distance_avx.h +++ b/src/include/detail/scoring/l2_distance_avx.h @@ -109,7 +109,8 @@ inline float avx2_sum_of_squares(const V& a, const W& b) { */ template requires std::same_as && - std::same_as + (std::same_as || + std::same_as) inline float avx2_sum_of_squares(const V& a, const W& b) { // @todo Align on 256 bit boundaries const size_t start = 0; @@ -117,25 +118,32 @@ inline float avx2_sum_of_squares(const V& a, const W& b) { const size_t stop = size_a - (size_a % 8); const float* a_ptr = a.data(); - const uint8_t* b_ptr = b.data(); + const auto* b_ptr = b.data(); __m256 vec_sum = _mm256_setzero_ps(); for (size_t i = start; i < stop; i += 8) { // Load 8 floats - __m256 a_floats = _mm256_loadu_ps(a_ptr + i + 0); + __m256 a_floats = _mm256_loadu_ps(a_ptr + i); // Load 8 bytes __m128i vec_b = _mm_loadu_si64((__m64*)(b_ptr + i)); - // Zero extend 8bit to 32bit ints - __m256i b_ints = _mm256_cvtepu8_epi32(vec_b); + // Extend 8 bit to 32 bit ints + __m256i b_ints; + if constexpr (std::same_as) { + // Zero extend uint8_t to int32_t + b_ints = _mm256_cvtepu8_epi32(vec_b); + } else { + // Sign extend int8_t to int32_t + b_ints = _mm256_cvtepi8_epi32(vec_b); + } // Convert signed integers to floats __m256 b_floats = _mm256_cvtepi32_ps(b_ints); // Subtract floats - __m256i diff = _mm256_sub_ps(a_floats, b_floats); + __m256 diff = _mm256_sub_ps(a_floats, b_floats); // Square and add with fmadd vec_sum = _mm256_fmadd_ps(diff, diff, vec_sum); @@ -167,15 +175,16 @@ inline float avx2_sum_of_squares(const V& a, const W& b) { * uint8_t - float */ template - requires std::same_as && - std::same_as + requires(std::same_as || + std::same_as) && + std::same_as inline float avx2_sum_of_squares(const V& a, const W& b) { // @todo Align on 256 bit boundaries const size_t start = 0; const size_t size_a = size(a); const size_t stop = size_a - (size_a % 8); - const uint8_t* a_ptr = a.data(); + const auto* a_ptr = a.data(); const float* b_ptr = b.data(); __m256 vec_sum = _mm256_setzero_ps(); @@ -183,18 +192,22 @@ inline float avx2_sum_of_squares(const V& a, const W& b) { for (size_t i = start; i < stop; i += 8) { // Load 8 bytes __m128i vec_a = _mm_loadu_si64((__m64*)(a_ptr + i)); + __m256 b_floats = _mm256_loadu_ps(b_ptr + i); - // Load 8 floats - __m256 b_floats = _mm256_loadu_ps(b_ptr + i + 0); - - // Zero extend 8bit to 32bit ints - __m256i a_ints = _mm256_cvtepu8_epi32(vec_a); + __m256i a_ints; + if constexpr (std::same_as) { + // Zero extend uint8_t to int32_t + a_ints = _mm256_cvtepu8_epi32(vec_a); + } else { + // Sign extend int8_t to int32_t + a_ints = _mm256_cvtepi8_epi32(vec_a); + } // Convert signed integers to floats __m256 a_floats = _mm256_cvtepi32_ps(a_ints); // Subtract floats - __m256i diff = _mm256_sub_ps(a_floats, b_floats); + __m256 diff = _mm256_sub_ps(a_floats, b_floats); // Square and add with fmadd vec_sum = _mm256_fmadd_ps(diff, diff, vec_sum); @@ -226,16 +239,18 @@ inline float avx2_sum_of_squares(const V& a, const W& b) { * uint8_t - uint8_t */ template - requires std::same_as && - std::same_as + requires(std::same_as || + std::same_as) && + (std::same_as || + std::same_as) inline float avx2_sum_of_squares(const V& a, const W& b) { // @todo Align on 256 bit boundaries const size_t start = 0; const size_t size_a = size(a); const size_t stop = size_a - (size_a % 8); - const uint8_t* a_ptr = a.data(); - const uint8_t* b_ptr = b.data(); + const auto* a_ptr = a.data(); + const auto* b_ptr = b.data(); __m256 vec_sum = _mm256_setzero_ps(); @@ -244,36 +259,28 @@ inline float avx2_sum_of_squares(const V& a, const W& b) { __m128i vec_a = _mm_loadu_si64((__m64*)(a_ptr + i)); __m128i vec_b = _mm_loadu_si64((__m64*)(b_ptr + i)); - // Zero extend 8bit to 32bit ints - __m256i a_ints = _mm256_cvtepu8_epi32(vec_a); - __m256i b_ints = _mm256_cvtepu8_epi32(vec_b); - - // Two alternatives for computing difference - 2nd seems faster -#if 0 - // Convert signed integers to floats - __m256 a_floats = _mm256_cvtepi32_ps(a_ints); - __m256 b_floats = _mm256_cvtepi32_ps(b_ints); - - // Subtract floats - __m256i diff = _mm256_sub_ps(a_floats, b_floats); -#else - // Subtract signed integers - __m256 i_diff = _mm256_sub_epi32(a_ints, b_ints); - - // Convert integers to floats + // Extend 8 bit to 32 bit ints + __m256i a_ints, b_ints; + if constexpr (std::same_as) { + // Zero extend uint8_t to int32_t + a_ints = _mm256_cvtepu8_epi32(vec_a); + } else { + // Sign extend int8_t to int32_t + a_ints = _mm256_cvtepi8_epi32(vec_a); + } + + if constexpr (std::same_as) { + // Zero extend uint8_t to int32_t + b_ints = _mm256_cvtepu8_epi32(vec_b); + } else { + // Sign extend int8_t to int32_t + b_ints = _mm256_cvtepi8_epi32(vec_b); + } + + __m256i i_diff = _mm256_sub_epi32(a_ints, b_ints); __m256 diff = _mm256_cvtepi32_ps(i_diff); -#endif - // Two alternatives for squaring and accumulating -- 2nd seems faster -#if 0 - // Square and add in two steps - __m256 diff_2 = _mm256_mul_ps(diff, diff); - vec_sum = _mm256_add_ps(vec_sum, diff_2); - -#else - // Square and add with fmadd vec_sum = _mm256_fmadd_ps(diff, diff, vec_sum); -#endif } // 8 to 4 @@ -346,14 +353,16 @@ inline float avx2_sum_of_squares(const V& a) { * uint8_t */ template - requires std::same_as + requires( + std::same_as || + std::same_as) inline float avx2_sum_of_squares(const V& a) { // @todo Align on 256 bit boundaries const size_t start = 0; const size_t size_a = size(a); const size_t stop = size_a - (size_a % 8); - const uint8_t* a_ptr = a.data(); + const auto* a_ptr = a.data(); __m256 vec_sum = _mm256_setzero_ps(); @@ -361,8 +370,15 @@ inline float avx2_sum_of_squares(const V& a) { // Load 8 bytes == 64 bits -- zeros out top 8 bytes __m128i vec_a = _mm_loadu_si64((__m64*)(a_ptr + i)); - // Zero extend 8bit to 32bit ints - __m256i a_ints = _mm256_cvtepu8_epi32(vec_a); + // Extend 8 bit to 32 bit ints + __m256i a_ints; + if constexpr (std::same_as) { + // Zero extend uint8_t to int32_t + a_ints = _mm256_cvtepu8_epi32(vec_a); + } else { + // Sign extend int8_t to int32_t + a_ints = _mm256_cvtepi8_epi32(vec_a); + } // Convert integers to floats __m256 a_floats = _mm256_cvtepi32_ps(a_ints); diff --git a/src/include/test/unit_inner_product_distance.cc b/src/include/test/unit_inner_product_distance.cc index 99adceb88..627f31df1 100644 --- a/src/include/test/unit_inner_product_distance.cc +++ b/src/include/test/unit_inner_product_distance.cc @@ -31,6 +31,7 @@ */ #include #include "detail/scoring/inner_product.h" +#include "detail/scoring/inner_product_avx.h" TEST_CASE("simple vectors", "[inner_product_distance]") { auto u = std::vector{1, 2, 3, 4}; @@ -85,16 +86,42 @@ TEST_CASE("simple vectors", "[inner_product_distance]") { #endif } -TEST_CASE( - "inner_distance: naive_inner_product with longer vectors", - "[inner_product_distance]") { +TEST_CASE("simple longer vectors", "[inner_product_distance]") { + size_t n = 100; + auto x = std::vector(n); + for (size_t i = 0; i < n; ++i) { + x[i] = i * 1.1f; + } + auto y = std::vector(n); + for (size_t i = 0; i < n; ++i) { + y[i] = i * 10.f; + } + float manual = 0.f; + for (size_t i = 0; i < n; ++i) { + manual += x[i] * y[i]; + } + + auto naive_xy = naive_inner_product(x, y); + auto unroll4_xy = unroll4_inner_product(x, y); + + CHECK(std::abs(naive_xy - manual) < 0.01); + CHECK(std::abs(unroll4_xy - manual) < 0.01); + +#ifdef __AVX2__ + auto avx2_xy = avx2_inner_product(x, y); + CHECK(std::abs(avx2_xy - manual) < 0.01); +#endif +} + +TEMPLATE_TEST_CASE( + "complex longer vectors", "[inner_product_distance]", int8_t, uint8_t) { // size_t n = GENERATE(1, 3, 127, 1021, 1024); // @todo generate a range of sizes (will require range of expected answers) - size_t n = GENERATE(127); + size_t n = GENERATE(55); - auto u = std::vector(n); - auto v = std::vector(n); + auto u = std::vector(n); + auto v = std::vector(n); auto x = std::vector(n); auto y = std::vector(n); @@ -107,69 +134,76 @@ TEST_CASE( std::iota(begin(z), end(z), 13); - auto a = naive_inner_product(x, y); - CHECK(std::abs(a - 785444.4375) < 0.01); - auto ax = naive_inner_product(y, x); - CHECK(ax == a); - CHECK(std::abs(ax - 785444.4375) < 0.01); + { + float expected = 72735.734375f; + auto naive_xy = naive_inner_product(x, y); + CHECK(std::abs(naive_xy - expected) < 0.01); + CHECK(naive_xy == naive_inner_product(y, x)); - auto b = unroll4_inner_product(x, y); - CHECK(std::abs(b - 785444.4375) < 0.01); - CHECK(a == b); - auto bx = unroll4_inner_product(y, x); - CHECK(bx == b); - CHECK(std::abs(bx - 785444.4375) < 0.01); - - auto a2 = naive_inner_product(u, v); - CHECK(a2 == 778764); - auto b2 = unroll4_inner_product(u, v); - CHECK(b2 == 778764); - CHECK(a2 == b2); - - auto a2x = naive_inner_product(v, u); - CHECK(a2x == 778764); - CHECK(a2 == a2x); - auto b2x = unroll4_inner_product(v, u); - CHECK(b2x == 778764); - CHECK(a2x == b2x); - - auto a3 = naive_inner_product(u, x); - CHECK(std::abs(a3 - 649615.1875) < 0.25); - auto b3 = unroll4_inner_product(u, x); - CHECK(std::abs(a3 - 649615.1875) < 0.25); - CHECK(std::abs(a3 - b3) < 0.25); - - auto a3x = naive_inner_product(x, u); - CHECK(std::abs(a3x - 649615.1875) < 0.25); - CHECK(a3 == a3x); - auto b3x = unroll4_inner_product(x, u); - CHECK(std::abs(a3x - 649615.1875) < 0.25); - CHECK(std::abs(a3x - b3x) < 0.25); + auto unroll4_xy = unroll4_inner_product(x, y); + CHECK(std::abs(unroll4_xy - expected) < 0.01); + CHECK(unroll4_xy == unroll4_inner_product(y, x)); #ifdef __AVX2__ + auto avx2_xy = avx2_inner_product(x, y); + CHECK(std::abs(avx2_xy - expected) < 0.07); + CHECK(avx2_xy == avx2_inner_product(y, x)); +#endif + } + { - auto a = avxs_inner_product(x, y); - CHECK(std::abs(a - 785444.4375) < 0.01); + auto naive_uv = naive_inner_product(u, v); + CHECK(naive_uv == 73260); + CHECK(naive_uv == naive_inner_product(v, u)); - auto ax = avx2_inner_product(y, x); - CHECK(ax == a); - CHECK(std::abs(ax - 785444.4375) < 0.01); + CHECK(naive_uv == unroll4_inner_product(u, v)); + CHECK(naive_uv == unroll4_inner_product(v, u)); - auto a2 = avx2_inner_product(u, v); - CHECK(a2 == 778764); +#ifdef __AVX2__ + CHECK(naive_uv == avx2_inner_product(u, v)); + CHECK(naive_uv == avx2_inner_product(v, u)); +#endif + } - auto a2x = avx2_inner_product(v, u); - CHECK(a2x == 778764); - CHECK(a2 == a2x); + { + float expected = 49289.7382812f; + auto naive_ux = naive_inner_product(u, x); + CHECK(std::abs(naive_ux - expected) < 0.05); + CHECK(naive_ux == naive_inner_product(x, u)); - auto a3 = avx2_inner_product(u, x); - CHECK(std::abs(a3 - 649615.1875) < 0.25); + auto unroll4_ux = unroll4_inner_product(u, x); + CHECK(std::abs(unroll4_ux - expected) < 0.05); - auto a3x = avx2_inner_product(x, u); - CHECK(std::abs(a3x - 649615.1875) < 0.25); - CHECK(a3 == a3x); + auto unroll4_xu = unroll4_inner_product(x, u); + CHECK(std::abs(unroll4_xu - expected) < 0.05); + +#ifdef __AVX2__ + auto avx2_ux = avx2_inner_product(u, x); + CHECK(std::abs(avx2_ux - expected) < 0.05); + + CHECK(avx2_ux == avx2_inner_product(x, u)); +#endif } + + { + float expected = 112568.59375f; + auto naive_vy = naive_inner_product(v, y); + CHECK(std::abs(naive_vy - expected) < 0.05); + CHECK(naive_vy == naive_inner_product(y, v)); + + auto unroll4_vy = unroll4_inner_product(v, y); + CHECK(std::abs(unroll4_vy - expected) < 0.05); + + auto unroll4_yv = unroll4_inner_product(y, v); + CHECK(std::abs(unroll4_yv - expected) < 0.05); + +#ifdef __AVX2__ + auto avx2_vy = avx2_inner_product(v, y); + CHECK(std::abs(avx2_vy - expected) < 0.05); + + CHECK(avx2_vy == avx2_inner_product(y, v)); #endif + } // @todo: inner_product does not yet have a partitioned version implemented yet. // Leaving code here for future reference diff --git a/src/include/test/unit_l2_distance.cc b/src/include/test/unit_l2_distance.cc index 8e958cdee..3f7551aae 100644 --- a/src/include/test/unit_l2_distance.cc +++ b/src/include/test/unit_l2_distance.cc @@ -31,9 +31,100 @@ */ #include +#include #include "detail/scoring/l2_distance.h" +#include "detail/scoring/l2_distance_avx.h" -TEMPLATE_TEST_CASE("naive_sum_of_squares", "[l2_distance]", int8_t, uint8_t) { +TEMPLATE_TEST_CASE("simple sum_of_squares", "[l2_distance]", int8_t, uint8_t) { + std::cout << std::fixed << std::setprecision(7); + + size_t n = GENERATE(55); + + auto u = std::vector(n); + auto v = std::vector(n); + auto x = std::vector(n); + auto y = std::vector(n); + + std::iota(begin(u), end(u), 0); + std::iota(begin(v), end(v), 5); + std::iota(begin(x), end(x), -3.14159); + std::iota(begin(y), end(y), 2.71828); + + { + float expected = 1888.5957031; + + auto naive_xy = naive_sum_of_squares(x, y); + CHECK(std::abs(naive_xy - expected) < 0.01); + CHECK(naive_xy == naive_sum_of_squares(y, x)); + + auto unroll4_xy = unroll4_sum_of_squares(x, y); + CHECK(std::abs(unroll4_xy - expected) < 0.01); + CHECK(unroll4_xy == unroll4_sum_of_squares(y, x)); + +#ifdef __AVX2__ + auto avx2_xy = avx2_sum_of_squares(x, y); + CHECK(std::abs(avx2_xy - expected) < 0.01); + CHECK(avx2_xy == avx2_sum_of_squares(y, x)); +#endif + } + + { + float expected = 1375.f; + + auto naive_uv = naive_sum_of_squares(u, v); + CHECK(std::abs(naive_uv - expected) < 0.01); + CHECK(naive_uv == naive_sum_of_squares(v, u)); + + auto unroll4_uv = unroll4_sum_of_squares(u, v); + CHECK(std::abs(unroll4_uv - expected) < 0.01); + CHECK(unroll4_uv == unroll4_sum_of_squares(v, u)); + +#ifdef __AVX2__ + auto avx2_uv = avx2_sum_of_squares(u, v); + CHECK(std::abs(avx2_uv - expected) < 0.01); + CHECK(avx2_uv == avx2_sum_of_squares(v, u)); +#endif + } + + { + float expected_ux = 542.8275146f; + + auto naive_ux = naive_sum_of_squares(u, x); + CHECK(std::abs(naive_ux - expected_ux) < 0.01); + CHECK(naive_ux == naive_sum_of_squares(x, u)); + + auto unroll4_ux = unroll4_sum_of_squares(u, x); + CHECK(std::abs(unroll4_ux - expected_ux) < 0.01); + CHECK(unroll4_ux == unroll4_sum_of_squares(x, u)); + +#ifdef __AVX2__ + auto avx2_ux = avx2_sum_of_squares(u, x); + CHECK(std::abs(avx2_ux - expected_ux) < 0.01); + CHECK(avx2_ux == avx2_sum_of_squares(x, u)); +#endif + } + + { + float expected_vy = 286.3432922f; + + auto naive_vy = naive_sum_of_squares(v, y); + CHECK(std::abs(naive_vy - expected_vy) < 0.01); + CHECK(naive_vy == naive_sum_of_squares(y, v)); + + auto unroll4_vy = unroll4_sum_of_squares(v, y); + CHECK(std::abs(unroll4_vy - expected_vy) < 0.01); + CHECK(unroll4_vy == unroll4_sum_of_squares(y, v)); + +#ifdef __AVX2__ + auto avx2_vy = avx2_sum_of_squares(v, y); + CHECK(std::abs(avx2_vy - expected_vy) < 0.01); + CHECK(avx2_vy == avx2_sum_of_squares(y, v)); +#endif + } +} + +TEMPLATE_TEST_CASE( + "start and stop sum_of_squares", "[l2_distance]", int8_t, uint8_t) { // size_t n = GENERATE(1, 3, 127, 1021, 1024); size_t n = GENERATE(127); diff --git a/src/include/test/unit_scoring.cc b/src/include/test/unit_scoring.cc index c696e4ee3..436ca446d 100644 --- a/src/include/test/unit_scoring.cc +++ b/src/include/test/unit_scoring.cc @@ -717,7 +717,7 @@ inline float sum_of_squares_avx2(const V& a, const W& b) { } TEST_CASE("avx2", "[scoring]") { - ColMajorMatrix rand_a{ + ColMajorMatrix rand_a{{ {0, 1, 2, 3, 4, 5, 6, 7, 3, 1, 4}, {8, 9, 10, 11, 12, 13, 14, 15, 1, 5, 9}, {16, 17, 18, 19, 20, 21, 22, 23, 2, 6, 5}, @@ -739,9 +739,9 @@ TEST_CASE("avx2", "[scoring]") { {48, 49, 50, 51, 52, 53, 54, 55, 8, 4, 6}, {56, 57, 58, 59, 60, 61, 62, 63, 2, 6, 4}, {46, 65, 66, 67, 68, 69, 70, 71, 3, 3, 8}, - }; + }}; - ColMajorMatrix rand_b{ + ColMajorMatrix rand_b{{ {136, 135, 134, 33, 132, 131, 130, 129, 3, 8, 3}, {128, 127, 126, 125, 124, 123, 122, 121, 3, 4, 6}, {120, 119, 118, 117, 116, 115, 114, 113, 2, 6, 4}, @@ -751,9 +751,9 @@ TEST_CASE("avx2", "[scoring]") { {88, 87, 86, 85, 84, 83, 82, 81, 3, 5, 6}, {80, 79, 78, 77, 76, 75, 74, 73, 2, 9, 5}, {72, 71, 70, 69, 68, 67, 66, 65, 1, 4, 1}, - }; + }}; - ColMajorMatrix rand_0{ + ColMajorMatrix rand_0{{ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, @@ -763,7 +763,7 @@ TEST_CASE("avx2", "[scoring]") { {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - }; + }}; float sum_0 = 0; for (size_t i = 0; i < num_vectors(rand_a); ++i) {