Skip to content

Commit

Permalink
Add chunk tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vshampor committed Nov 11, 2024
1 parent 929e5ba commit c97cd62
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
#include <vector>
#include <cassert>

// TODO (vshampor): remove this
#define HAVE_AVX2
#define HAVE_AVX512F

#include "openvino/core/type/bfloat16.hpp"
#include "openvino/core/type/float16.hpp"

Expand Down
7 changes: 7 additions & 0 deletions src/plugins/intel_cpu/tests/unit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,16 @@ endif()

if(ENABLE_AVX2)
ov_avx2_optimization_flags(avx2_flags)
message("VSHAMPOR: passing AVX flags ${avx2_flags}")
target_compile_options(${TARGET_NAME} PRIVATE "${avx2_flags}")
endif()

if(ENABLE_AVX512F)
ov_avx512_optimization_flags(avx512_flags)
message("VSHAMPOR: passing AVX flags ${avx512_flags}")
target_compile_options(${TARGET_NAME} PRIVATE "${avx512_flags}")
endif()

# LTO
set_target_properties(${TARGET_NAME} PROPERTIES INTERPROCEDURAL_OPTIMIZATION_RELEASE ${ENABLE_LTO})

Expand Down
57 changes: 57 additions & 0 deletions src/plugins/intel_cpu/tests/unit/paged_attn_cache_rotation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,14 @@

#include <string>
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <memory>

// TODO (vshampor): remove this, find a way to forward the compile flags to the unit tests
#define HAVE_AVX2
#define HAVE_AVX512F

#include "kernels/scaled_attn/common.hpp"
#include "utils/plain_tensor.hpp"
#include "nodes/kernels/scaled_attn/cache_rotation.hpp"

Expand Down Expand Up @@ -171,21 +177,72 @@ enum class TargetInstructionSet {

using CacheRotationHWKernelTest = ::testing::TestWithParam<TargetInstructionSet>;

MATCHER_P3(IsNFirstValuesNear, ref_container, abs_err, n, "") {
if (ref_container.size() < n || arg.size() < n) return false;
if (ref_container.size() != arg.size()) return false;

bool is_ok = true;
for (size_t i = 0; i < ref_container.size(); i++)
{
if (!::testing::ExplainMatchResult(::testing::FloatNear(float(arg[i]), abs_err), float(ref_container[i]), result_listener))
{
*result_listener << " for element at idx " << i;
is_ok = false;
}
}
return is_ok;
}

TEST_P(CacheRotationHWKernelTest, HWChunkRotationGivesReferenceResults) {
auto instruction_set = GetParam();

constexpr size_t MAX_CHUNK_SIZE_IN_ELEMENTS = 16;
using MemChunk = std::array<float, MAX_CHUNK_SIZE_IN_ELEMENTS>;
MemChunk chunk_x = { -0.76777814, 0.97583583, -0.23619731, 0.19022397, 0.56691264, 0.64870757, 0.63334306, 1.97307894,
0.72495168, 1.22328697, -0.6005607, 0.17189973, -0.92268487, 0.40205632, 0.85996431, 1.70078315};

MemChunk chunk_y = { 1.68812157, -0.90722836, 0.58474063, -0.64561766, 0.62651501, 1.55990472, 0.41571189, 0.38366555,
0.09841767, 0.02218336, -0.07657361, 1.6062845, -1.08282323, -0.92034808, -1.48428038, 0.43501142};

MemChunk chunk_cos = { -0.87461971, 0.95630476, 0.08715574, 0.8480481, -0.9612617, 0.27563736, 0.97437006, 0.66913061,
-0.89100652, 0.98480775, -0.7313537, -0.2419219, 0.10452846, 0.70710678, -0.32556815, -0.2923717 };

MemChunk chunk_sin = { -0.48480962, -0.2923717, 0.9961947, 0.52991926, 0.27563736, -0.9612617, -0.22495105, 0.74314483,
0.4539905, -0.17364818, -0.68199836, -0.97029573, -0.9945219, -0.70710678, -0.94551858, 0.95630476 };

MemChunk ref_chunk_cos = chunk_cos;
MemChunk ref_chunk_sin = chunk_sin;

MemChunk ref_chunk_x = { 1.48993147, 0.66794854, -0.60310147, 0.50344431, -0.71764235, 1.6782847, 0.71062535, 1.03512844,
-0.69061736, 1.20855459, 0.38699921, 1.51698468, -1.17333824, -0.36648762, -1.68339166, -0.91326436 };

MemChunk ref_chunk_y = { -1.10423816, -1.15289358, -0.184335, -0.44671148, -0.44598258, -0.19360973, 0.26258603, 1.72300577,
0.24143039, -0.19057521, 0.46558381, -0.55538896, 0.80444446, -0.93508112, -0.32987781, 1.49928198 };


size_t vec_len_in_elts = 0;
switch(instruction_set) {
case TargetInstructionSet::AVX2:
vec_len_in_elts = ov::Extensions::Cpu::XARCH::vec_len_f32_avx2;
rotate_kv_cache_chunk_avx2(chunk_x.data(), chunk_y.data(), chunk_cos.data(), vec_len_in_elts, /* is_underutilizing = */ false);
break;
case TargetInstructionSet::AVX512:
vec_len_in_elts = ov::Extensions::Cpu::XARCH::vec_len_f32_avx512;
rotate_kv_cache_chunk_avx512(chunk_x.data(), chunk_y.data(), chunk_cos.data(), chunk_sin.data(), vec_len_in_elts, /* is_underutilizing = */false);
break;
default:
FAIL() << "unknown target instruction set";
}

EXPECT_THAT(chunk_x, IsNFirstValuesNear(ref_chunk_x, 1e-6, vec_len_in_elts));
EXPECT_THAT(chunk_y, IsNFirstValuesNear(ref_chunk_y, 1e-6, vec_len_in_elts));

EXPECT_EQ(chunk_cos, ref_chunk_cos);
EXPECT_EQ(chunk_sin, ref_chunk_sin);

}

INSTANTIATE_TEST_SUITE_P(VariousInstructionSets, CacheRotationHWKernelTest, ::testing::Values{TargetInstructionSet::AVX2, TargetInstructionSet::AVX512});

TYPED_TEST(CacheRotationKernelTest, HWBlockRotationGivesReferenceResults) {
auto raw_cache_mem_ptr = this->cache_mem_ptr.get();
Expand Down

0 comments on commit c97cd62

Please sign in to comment.