From c97cd6256cd18a748355432b9064892c139fce24 Mon Sep 17 00:00:00 2001
From: Vasily Shamporov
Date: Mon, 11 Nov 2024 14:10:44 +0100
Subject: [PATCH] Add chunk tests

---
Notes (free-form text here is ignored when the patch is applied):

* Adds the HWChunkRotationGivesReferenceResults unit test, which runs the
  AVX2 / AVX512 chunk rotation kernels on one 16-element chunk and compares
  the first `vec_len_in_elts` results against precomputed reference values.
  The reference values equal x*cos - y*sin and x*sin + y*cos of the inputs.
* IsNFirstValuesNear(ref, abs_err, n) requires both containers to have the
  same size (at least n) and compares only elements [0, n).
* HAVE_AVX2 / HAVE_AVX512F are force-defined in common.hpp and in the test
  as a temporary measure (see the TODOs in the hunks below).
* NOTE(review): text between angle brackets is missing from this copy of the
  patch: the author e-mail and the header names of the bare `#include`
  lines below. The template arguments of TestWithParam and std::array were
  restored from how the test uses them (GetParam() is switched over
  TargetInstructionSet; elements are cast to float) -- TODO confirm.
* NOTE(review): blank context lines were re-inserted by best guess, so hunk
  line counts may not match; verify against the tree or apply with
  `git apply --recount`.

 .../src/nodes/kernels/scaled_attn/common.hpp  |  4 ++
 .../intel_cpu/tests/unit/CMakeLists.txt       |  7 +++
 .../tests/unit/paged_attn_cache_rotation.cpp  | 57 +++++++++++++++++++
 3 files changed, 68 insertions(+)

diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp
index c307b338d1454f..7a2770e47377d9 100644
--- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp
+++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp
@@ -9,6 +9,10 @@
 #include
 #include

+// TODO (vshampor): remove this
+#define HAVE_AVX2
+#define HAVE_AVX512F
+
 #include "openvino/core/type/bfloat16.hpp"
 #include "openvino/core/type/float16.hpp"

diff --git a/src/plugins/intel_cpu/tests/unit/CMakeLists.txt b/src/plugins/intel_cpu/tests/unit/CMakeLists.txt
index a193505ef3f228..fa66a3e4730d88 100644
--- a/src/plugins/intel_cpu/tests/unit/CMakeLists.txt
+++ b/src/plugins/intel_cpu/tests/unit/CMakeLists.txt
@@ -80,9 +80,16 @@ endif()

 if(ENABLE_AVX2)
     ov_avx2_optimization_flags(avx2_flags)
+    message("VSHAMPOR: passing AVX flags ${avx2_flags}")
     target_compile_options(${TARGET_NAME} PRIVATE "${avx2_flags}")
 endif()

+if(ENABLE_AVX512F)
+    ov_avx512_optimization_flags(avx512_flags)
+    message("VSHAMPOR: passing AVX flags ${avx512_flags}")
+    target_compile_options(${TARGET_NAME} PRIVATE "${avx512_flags}")
+endif()
+
 # LTO
 set_target_properties(${TARGET_NAME} PROPERTIES INTERPROCEDURAL_OPTIMIZATION_RELEASE ${ENABLE_LTO})

diff --git a/src/plugins/intel_cpu/tests/unit/paged_attn_cache_rotation.cpp b/src/plugins/intel_cpu/tests/unit/paged_attn_cache_rotation.cpp
index 29c64caf4a19da..e86badc60911b4 100644
--- a/src/plugins/intel_cpu/tests/unit/paged_attn_cache_rotation.cpp
+++ b/src/plugins/intel_cpu/tests/unit/paged_attn_cache_rotation.cpp
@@ -5,8 +5,14 @@
 #include
 #include

+#include
 #include

+// TODO (vshampor): remove this, find a way to forward the compile flags to the unit tests
+#define HAVE_AVX2
+#define HAVE_AVX512F
+
+#include "kernels/scaled_attn/common.hpp"
 #include "utils/plain_tensor.hpp"
 #include "nodes/kernels/scaled_attn/cache_rotation.hpp"

@@ -171,21 +177,72 @@ enum class TargetInstructionSet {

 using CacheRotationHWKernelTest = ::testing::TestWithParam<TargetInstructionSet>;

+MATCHER_P3(IsNFirstValuesNear, ref_container, abs_err, n, "") {
+    if (ref_container.size() < n || arg.size() < n) return false;
+    if (ref_container.size() != arg.size()) return false;
+
+    bool is_ok = true;
+    for (size_t i = 0; i < n; i++)  // only the first n elements take part in the comparison
+    {
+        if (!::testing::ExplainMatchResult(::testing::FloatNear(float(arg[i]), abs_err), float(ref_container[i]), result_listener))
+        {
+            *result_listener << " for element at idx " << i;
+            is_ok = false;  // keep going so that every mismatching index gets reported
+        }
+    }
+    return is_ok;
+}
+
 TEST_P(CacheRotationHWKernelTest, HWChunkRotationGivesReferenceResults) {
     auto instruction_set = GetParam();
+    constexpr size_t MAX_CHUNK_SIZE_IN_ELEMENTS = 16;
+    using MemChunk = std::array<float, MAX_CHUNK_SIZE_IN_ELEMENTS>;
+    MemChunk chunk_x = { -0.76777814, 0.97583583, -0.23619731, 0.19022397, 0.56691264, 0.64870757, 0.63334306, 1.97307894,
+                         0.72495168, 1.22328697, -0.6005607, 0.17189973, -0.92268487, 0.40205632, 0.85996431, 1.70078315};
+
+    MemChunk chunk_y = { 1.68812157, -0.90722836, 0.58474063, -0.64561766, 0.62651501, 1.55990472, 0.41571189, 0.38366555,
+                         0.09841767, 0.02218336, -0.07657361, 1.6062845, -1.08282323, -0.92034808, -1.48428038, 0.43501142};
+
+    MemChunk chunk_cos = { -0.87461971, 0.95630476, 0.08715574, 0.8480481, -0.9612617, 0.27563736, 0.97437006, 0.66913061,
+                           -0.89100652, 0.98480775, -0.7313537, -0.2419219, 0.10452846, 0.70710678, -0.32556815, -0.2923717 };
+
+    MemChunk chunk_sin = { -0.48480962, -0.2923717, 0.9961947, 0.52991926, 0.27563736, -0.9612617, -0.22495105, 0.74314483,
+                           0.4539905, -0.17364818, -0.68199836, -0.97029573, -0.9945219, -0.70710678, -0.94551858, 0.95630476 };
+
+    MemChunk ref_chunk_cos = chunk_cos;  // copies taken before the kernel call; compared afterwards
+    MemChunk ref_chunk_sin = chunk_sin;
+
+    MemChunk ref_chunk_x = { 1.48993147, 0.66794854, -0.60310147, 0.50344431, -0.71764235, 1.6782847, 0.71062535, 1.03512844,
+                             -0.69061736, 1.20855459, 0.38699921, 1.51698468, -1.17333824, -0.36648762, -1.68339166, -0.91326436 };
+
+    MemChunk ref_chunk_y = { -1.10423816, -1.15289358, -0.184335, -0.44671148, -0.44598258, -0.19360973, 0.26258603, 1.72300577,
+                             0.24143039, -0.19057521, 0.46558381, -0.55538896, 0.80444446, -0.93508112, -0.32987781, 1.49928198 };
+
+
+    size_t vec_len_in_elts = 0;
     switch(instruction_set) {
         case TargetInstructionSet::AVX2:
+            vec_len_in_elts = ov::Extensions::Cpu::XARCH::vec_len_f32_avx2;
+            rotate_kv_cache_chunk_avx2(chunk_x.data(), chunk_y.data(), chunk_cos.data(), chunk_sin.data(), vec_len_in_elts, /* is_underutilizing = */ false);
             break;
         case TargetInstructionSet::AVX512:
+            vec_len_in_elts = ov::Extensions::Cpu::XARCH::vec_len_f32_avx512;
+            rotate_kv_cache_chunk_avx512(chunk_x.data(), chunk_y.data(), chunk_cos.data(), chunk_sin.data(), vec_len_in_elts, /* is_underutilizing = */false);
             break;
         default:
             FAIL() << "unknown target instruction set";
     }
+    EXPECT_THAT(chunk_x, IsNFirstValuesNear(ref_chunk_x, 1e-6, vec_len_in_elts));
+    EXPECT_THAT(chunk_y, IsNFirstValuesNear(ref_chunk_y, 1e-6, vec_len_in_elts));
+
+    EXPECT_EQ(chunk_cos, ref_chunk_cos);
+    EXPECT_EQ(chunk_sin, ref_chunk_sin);
 }
+INSTANTIATE_TEST_SUITE_P(VariousInstructionSets, CacheRotationHWKernelTest, ::testing::Values(TargetInstructionSet::AVX2, TargetInstructionSet::AVX512));

 TYPED_TEST(CacheRotationKernelTest, HWBlockRotationGivesReferenceResults) {
     auto raw_cache_mem_ptr = this->cache_mem_ptr.get();