Skip to content

Commit

Permalink
Add testing for chunk rotation in FP16 and BF16
Browse files Browse the repository at this point in the history
  • Loading branch information
vshampor committed Nov 12, 2024
1 parent 7722675 commit ed46cfe
Showing 1 changed file with 78 additions and 57 deletions.
135 changes: 78 additions & 57 deletions src/plugins/intel_cpu/tests/unit/paged_attn_cache_rotation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,23 @@ void compare_with_tolerance(const Rank3Matrix<T>& test_data, const Rank3Matrix<T
}


// Returns the tolerance, expressed in element type T, that the tests in this file
// pass to their comparison helpers. The primary template yields a value-initialized T
// (zero for arithmetic types); explicit specializations provide the per-type values.
template <class T>
static T get_tolerance() {
    return T{};
}

// Tolerance used for float comparisons.
template <>
float get_tolerance<float>() {
    // float literal: avoids an implicit double -> float conversion in the return statement
    return 1e-6f;
}

// Tolerance used for ov::float16 comparisons (1e-3, converted to ov::float16).
template <>
ov::float16 get_tolerance<ov::float16>() {
    return ov::float16{1e-3};
}

// Tolerance used for ov::bfloat16 comparisons (2e-2, converted to ov::bfloat16).
template <>
ov::bfloat16 get_tolerance<ov::bfloat16>() {
    return ov::bfloat16{2e-2};
}

template<class TypeParam>
class CacheRotationKernelTest : public ::testing::Test {
public:
TypeParam get_tolerance() { return TypeParam(1e-6); }
void SetUp() override {
Rank3Matrix<TypeParam> values_before_rotation = {
{
Expand Down Expand Up @@ -154,12 +167,6 @@ class CacheRotationKernelTest : public ::testing::Test {
};
};

// ov::float16 specialization of the fixture's tolerance accessor: returns 1e-3 as ov::float16.
template<>
ov::float16 CacheRotationKernelTest<ov::float16>::get_tolerance() {
    const ov::float16 fp16_tolerance(1e-3);
    return fp16_tolerance;
}

// ov::bfloat16 specialization of the fixture's tolerance accessor: returns 2e-2 as ov::bfloat16.
template<>
ov::bfloat16 CacheRotationKernelTest<ov::bfloat16>::get_tolerance() {
    const ov::bfloat16 bf16_tolerance(2e-2);
    return bf16_tolerance;
}

using OV_FP_TYPES = ::testing::Types<float, ov::float16, ov::bfloat16>;

TYPED_TEST_SUITE(CacheRotationKernelTest, OV_FP_TYPES);
Expand All @@ -171,16 +178,14 @@ TYPED_TEST(CacheRotationKernelTest, SWBlockRotationGivesReferenceResults) {
rotate_kv_cache_block_sw(raw_cache_mem_ptr, raw_rotation_coefficients_mem_ptr, this->num_heads, this->block_size, this->embedding_size);

auto test_values_after_rotation = get_matrix_from_mem(this->cache_mem_ptr, this->num_heads, this->block_size, this->embedding_size);
compare_with_tolerance(test_values_after_rotation, this->ref_values_after_rotation, this->get_tolerance());
compare_with_tolerance(test_values_after_rotation, this->ref_values_after_rotation, get_tolerance<TypeParam>());
}

// Selects which chunk-rotation kernel variant (AVX2 or AVX512) a parameterized test case calls.
enum class TargetInstructionSet { AVX2, AVX512 };

using CacheRotationHWKernelTest = ::testing::TestWithParam<std::tuple<TargetInstructionSet, size_t>>;

MATCHER_P3(IsNFirstValuesNear, ref_container, abs_err, n, "") {
if (ref_container.size() < n || arg.size() < n) return false;
if (ref_container.size() != arg.size()) return false;
Expand All @@ -197,55 +202,71 @@ MATCHER_P3(IsNFirstValuesNear, ref_container, abs_err, n, "") {
return is_ok;
}

TEST_P(CacheRotationHWKernelTest, HWChunkRotationGivesReferenceResults) {
auto instruction_set = std::get<0>(GetParam());
auto num_elements_to_process = std::get<1>(GetParam());

constexpr size_t MAX_CHUNK_SIZE_IN_ELEMENTS = 16;
using MemChunk = std::array<float, MAX_CHUNK_SIZE_IN_ELEMENTS>;
MemChunk chunk_x = { -0.76777814, 0.97583583, -0.23619731, 0.19022397, 0.56691264, 0.64870757, 0.63334306, 1.97307894,
0.72495168, 1.22328697, -0.6005607, 0.17189973, -0.92268487, 0.40205632, 0.85996431, 1.70078315};

MemChunk chunk_y = { 1.68812157, -0.90722836, 0.58474063, -0.64561766, 0.62651501, 1.55990472, 0.41571189, 0.38366555,
0.09841767, 0.02218336, -0.07657361, 1.6062845, -1.08282323, -0.92034808, -1.48428038, 0.43501142};

MemChunk chunk_cos = { -0.87461971, 0.95630476, 0.08715574, 0.8480481, -0.9612617, 0.27563736, 0.97437006, 0.66913061,
-0.89100652, 0.98480775, -0.7313537, -0.2419219, 0.10452846, 0.70710678, -0.32556815, -0.2923717 };

MemChunk chunk_sin = { -0.48480962, -0.2923717, 0.9961947, 0.52991926, 0.27563736, -0.9612617, -0.22495105, 0.74314483,
0.4539905, -0.17364818, -0.68199836, -0.97029573, -0.9945219, -0.70710678, -0.94551858, 0.95630476 };

MemChunk ref_chunk_cos = chunk_cos;
MemChunk ref_chunk_sin = chunk_sin;

MemChunk ref_chunk_x = { 1.48993147, 0.66794854, -0.60310147, 0.50344431, -0.71764235, 1.6782847, 0.71062535, 1.03512844,
-0.69061736, 1.20855459, 0.38699921, 1.51698468, -1.17333824, -0.36648762, -1.68339166, -0.91326436 };

MemChunk ref_chunk_y = { -1.10423816, -1.15289358, -0.184335, -0.44671148, -0.44598258, -0.19360973, 0.26258603, 1.72300577,
0.24143039, -0.19057521, 0.46558381, -0.55538896, 0.80444446, -0.93508112, -0.32987781, 1.49928198 };

// unprocessed elements should remain untouched
std::copy(chunk_x.begin() + num_elements_to_process, chunk_x.end(), ref_chunk_x.begin() + num_elements_to_process);
std::copy(chunk_y.begin() + num_elements_to_process, chunk_y.end(), ref_chunk_y.begin() + num_elements_to_process);

switch(instruction_set) {
using namespace ov::Extensions::Cpu::XARCH;
case TargetInstructionSet::AVX2:
rotate_kv_cache_chunk_avx2(chunk_x.data(), chunk_y.data(), chunk_cos.data(), chunk_sin.data(), num_elements_to_process, /* is_underutilizing = */ num_elements_to_process < vec_len_f32_avx2);
break;
case TargetInstructionSet::AVX512:
rotate_kv_cache_chunk_avx512(chunk_x.data(), chunk_y.data(), chunk_cos.data(), chunk_sin.data(), num_elements_to_process, /* is_underutilizing = */ num_elements_to_process < vec_len_f32_avx512);
break;
default:
FAIL() << "unknown target instruction set";
}
// Value-parameterized fixture for the chunk rotation kernels. The test parameter is a
// (TargetInstructionSet, size_t) tuple: the kernel variant to call and the number of
// elements it is asked to process.
class CacheRotationHWKernelTest: public ::testing::TestWithParam<std::tuple<TargetInstructionSet, size_t>> {
protected:
// Number of elements in every test chunk below.
constexpr static size_t MAX_CHUNK_SIZE_IN_ELEMENTS = 16;
// Fixed-size chunk of MAX_CHUNK_SIZE_IN_ELEMENTS values of type T.
template<class T>
using MemChunk = std::array<T, MAX_CHUNK_SIZE_IN_ELEMENTS>;

// Calls the AVX2 or AVX512 chunk rotation kernel (chosen by the test parameter) on x/y
// chunks of element type T and checks the first num_elements_to_process results against
// reference values within get_tolerance<T>(). Also checks that the cos/sin coefficient
// chunks compare equal to copies taken before the kernel call.
template<class T>
void test_chunk_rotation_for_type() {
auto instruction_set = std::get<0>(GetParam());
auto num_elements_to_process = std::get<1>(GetParam());



// Input x values of the chunk, stored as T.
MemChunk<T> chunk_x = { -0.76777814, 0.97583583, -0.23619731, 0.19022397, 0.56691264, 0.64870757, 0.63334306, 1.97307894,
0.72495168, 1.22328697, -0.6005607, 0.17189973, -0.92268487, 0.40205632, 0.85996431, 1.70078315};

// NOTE(review): the two assertions below use ref_chunk_x / ref_chunk_y before they are declared
// and hard-code 1e-6; they look like removed-side lines left over from the diff rendering - confirm.
EXPECT_THAT(chunk_x, IsNFirstValuesNear(ref_chunk_x, 1e-6, num_elements_to_process));
EXPECT_THAT(chunk_y, IsNFirstValuesNear(ref_chunk_y, 1e-6, num_elements_to_process));
// Input y values of the chunk, stored as T.
MemChunk<T> chunk_y = { 1.68812157, -0.90722836, 0.58474063, -0.64561766, 0.62651501, 1.55990472, 0.41571189, 0.38366555,
0.09841767, 0.02218336, -0.07657361, 1.6062845, -1.08282323, -0.92034808, -1.48428038, 0.43501142};

// NOTE(review): same as above - chunk_cos / chunk_sin and their ref_* copies are declared
// further down; these two lines also look like diff residue - confirm.
EXPECT_EQ(chunk_cos, ref_chunk_cos);
EXPECT_EQ(chunk_sin, ref_chunk_sin);
// Cosine coefficients; always float, independent of T.
MemChunk<float> chunk_cos = { -0.87461971, 0.95630476, 0.08715574, 0.8480481, -0.9612617, 0.27563736, 0.97437006, 0.66913061,
-0.89100652, 0.98480775, -0.7313537, -0.2419219, 0.10452846, 0.70710678, -0.32556815, -0.2923717 };

// Sine coefficients; always float, independent of T.
MemChunk<float> chunk_sin = { -0.48480962, -0.2923717, 0.9961947, 0.52991926, 0.27563736, -0.9612617, -0.22495105, 0.74314483,
0.4539905, -0.17364818, -0.68199836, -0.97029573, -0.9945219, -0.70710678, -0.94551858, 0.95630476 };

// Copies of the coefficients taken before the kernel call, compared against afterwards.
MemChunk<float> ref_chunk_cos = chunk_cos;
MemChunk<float> ref_chunk_sin = chunk_sin;

// Reference x values expected after rotation of the full chunk.
MemChunk<T> ref_chunk_x = { 1.48993147, 0.66794854, -0.60310147, 0.50344431, -0.71764235, 1.6782847, 0.71062535, 1.03512844,
-0.69061736, 1.20855459, 0.38699921, 1.51698468, -1.17333824, -0.36648762, -1.68339166, -0.91326436 };

// Reference y values expected after rotation of the full chunk.
MemChunk<T> ref_chunk_y = { -1.10423816, -1.15289358, -0.184335, -0.44671148, -0.44598258, -0.19360973, 0.26258603, 1.72300577,
0.24143039, -0.19057521, 0.46558381, -0.55538896, 0.80444446, -0.93508112, -0.32987781, 1.49928198 };

// unprocessed elements should remain untouched
std::copy(chunk_x.begin() + num_elements_to_process, chunk_x.end(), ref_chunk_x.begin() + num_elements_to_process);
std::copy(chunk_y.begin() + num_elements_to_process, chunk_y.end(), ref_chunk_y.begin() + num_elements_to_process);

// Dispatch to the kernel variant under test; is_underutilizing is set when fewer elements
// than the corresponding vec_len_f32_* constant are requested.
switch(instruction_set) {
using namespace ov::Extensions::Cpu::XARCH;
case TargetInstructionSet::AVX2:
rotate_kv_cache_chunk_avx2(chunk_x.data(), chunk_y.data(), chunk_cos.data(), chunk_sin.data(), num_elements_to_process, /* is_underutilizing = */ num_elements_to_process < vec_len_f32_avx2);
break;
case TargetInstructionSet::AVX512:
rotate_kv_cache_chunk_avx512(chunk_x.data(), chunk_y.data(), chunk_cos.data(), chunk_sin.data(), num_elements_to_process, /* is_underutilizing = */ num_elements_to_process < vec_len_f32_avx512);
break;
default:
FAIL() << "unknown target instruction set";
}

// Element type name appended to every failure message below.
std::string type_name = ov::element::from<T>().to_string();

// Rotated values must match the reference within the per-type tolerance.
EXPECT_THAT(chunk_x, IsNFirstValuesNear(ref_chunk_x, get_tolerance<T>(), num_elements_to_process)) << ", element type is: " << type_name;
EXPECT_THAT(chunk_y, IsNFirstValuesNear(ref_chunk_y, get_tolerance<T>(), num_elements_to_process)) << ", element type is: " << type_name;

// Coefficient chunks must compare equal to the copies taken before the call.
EXPECT_EQ(chunk_cos, ref_chunk_cos) << ", element type is: " << type_name;
EXPECT_EQ(chunk_sin, ref_chunk_sin) << ", element type is: " << type_name;
}
};


// Runs the chunk rotation check once per element type, in the order
// float, ov::float16, ov::bfloat16, for each parameter combination.
TEST_P(CacheRotationHWKernelTest, HWChunkRotationGivesReferenceResults) {
    this->test_chunk_rotation_for_type<float>();
    this->test_chunk_rotation_for_type<ov::float16>();
    this->test_chunk_rotation_for_type<ov::bfloat16>();
}

auto TEST_STRUCT_TO_NAME_FN = [](const testing::TestParamInfo<CacheRotationHWKernelTest::ParamType>& info) {
Expand All @@ -269,6 +290,6 @@ TYPED_TEST(CacheRotationKernelTest, HWBlockRotationGivesReferenceResults) {
rotate_kv_cache_block_hw(raw_cache_mem_ptr, raw_rotation_coefficients_mem_ptr, this->num_heads, this->block_size, this->embedding_size);

auto test_values_after_rotation = get_matrix_from_mem(this->cache_mem_ptr, this->num_heads, this->block_size, this->embedding_size);
compare_with_tolerance(test_values_after_rotation, this->ref_values_after_rotation, this->get_tolerance());
compare_with_tolerance(test_values_after_rotation, this->ref_values_after_rotation, get_tolerance<TypeParam>());
}

0 comments on commit ed46cfe

Please sign in to comment.