Skip to content

Commit

Permalink
[CPU]fix code style
Browse files: browse the repository at this point in the history
Signed-off-by: Zhang Yi3 <[email protected]>
  • Loading branch information
zhangYiIntel committed Dec 10, 2024
1 parent 91ebc09 commit 685f263
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 88 deletions.
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/src/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) {
" for property key ",
key,
". Expected only unsinged integer numbers");
}
}
} else if (key == ov::cache_encryption_callbacks.name()) {
try {
auto encryption_callbacks = val.as<EncryptionCallbacks>();
Expand Down
102 changes: 73 additions & 29 deletions src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ static void find_minmax(const T* src, size_t n, float& min, float& max) {
max = std::max(max, tmp);
min = std::min(min, tmp);
}

}

template<typename T>
Expand Down Expand Up @@ -398,7 +397,10 @@ static void attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
});
}

template <typename T, ov::element::Type_t KEY_DST_PREC, ov::element::Type_t VALUE_DST_PREC, typename std::enable_if<VALUE_DST_PREC == ov::element::u8, bool>::type = true>
template <typename T,
ov::element::Type_t KEY_DST_PREC,
ov::element::Type_t VALUE_DST_PREC,
typename std::enable_if<VALUE_DST_PREC == ov::element::u8, bool>::type = true>
static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
const ov::intel_cpu::PlainTensor& v_src,
const ov::intel_cpu::PlainTensor& k_dst,
Expand All @@ -417,27 +419,48 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
auto block_offset = slot % block_size;
// The layout for per token per head:
// |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)|
for (size_t src_offset = 0, dst_offset = 0; src_offset < S; src_offset += _key_group_size, dst_offset += _key_group_size + sizeof(float) + sizeof(float)) {
auto p_k = reinterpret_cast<float*>(k_dst.ptr<typename ov::element_type_traits<KEY_DST_PREC>::value_type>(block_number, h, block_offset, dst_offset));
for (size_t src_offset = 0, dst_offset = 0; src_offset < S;
src_offset += _key_group_size, dst_offset += _key_group_size + sizeof(float) + sizeof(float)) {
auto p_k = reinterpret_cast<float*>(
k_dst.ptr<typename ov::element_type_traits<KEY_DST_PREC>::value_type>(block_number,
h,
block_offset,
dst_offset));
quant_u8(k_src.ptr<T>(b, h, m, src_offset),
k_dst.ptr<typename ov::element_type_traits<KEY_DST_PREC>::value_type>(block_number, h, block_offset, dst_offset) + sizeof(float) + sizeof(float),
_key_group_size,
p_k[0],
p_k[1]);
k_dst.ptr<typename ov::element_type_traits<KEY_DST_PREC>::value_type>(block_number,
h,
block_offset,
dst_offset) +
sizeof(float) + sizeof(float),
_key_group_size,
p_k[0],
p_k[1]);
}

for (size_t src_offset = 0, dst_offset = 0; src_offset < SV; src_offset += _value_group_size, dst_offset += _value_group_size + sizeof(float) + sizeof(float)) {
auto p_v = reinterpret_cast<float*>(v_dst.ptr<typename ov::element_type_traits<VALUE_DST_PREC>::value_type>(block_number, h, block_offset, dst_offset));

for (size_t src_offset = 0, dst_offset = 0; src_offset < SV;
src_offset += _value_group_size, dst_offset += _value_group_size + sizeof(float) + sizeof(float)) {
auto p_v = reinterpret_cast<float*>(
v_dst.ptr<typename ov::element_type_traits<VALUE_DST_PREC>::value_type>(block_number,
h,
block_offset,
dst_offset));
quant_u8(v_src.ptr<T>(b, h, m, src_offset),
v_dst.ptr<typename ov::element_type_traits<VALUE_DST_PREC>::value_type>(block_number, h, block_offset, dst_offset) + sizeof(float) + sizeof(float),
_value_group_size,
p_v[0],
p_v[1]);
v_dst.ptr<typename ov::element_type_traits<VALUE_DST_PREC>::value_type>(block_number,
h,
block_offset,
dst_offset) +
sizeof(float) + sizeof(float),
_value_group_size,
p_v[0],
p_v[1]);
}
});
}

template <typename T, ov::element::Type_t KEY_DST_PREC, ov::element::Type_t VALUE_DST_PREC, typename std::enable_if<VALUE_DST_PREC == ov::element::u4, bool>::type = true>
template <typename T,
ov::element::Type_t KEY_DST_PREC,
ov::element::Type_t VALUE_DST_PREC,
typename std::enable_if<VALUE_DST_PREC == ov::element::u4, bool>::type = true>
static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
const ov::intel_cpu::PlainTensor& v_src,
const ov::intel_cpu::PlainTensor& k_dst,
Expand All @@ -457,13 +480,22 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
auto block_offset = slot % block_size;
// The layout for per token per head:
// |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)|
for (size_t src_offset = 0, dst_offset = 0; src_offset < S; src_offset += _key_group_size, dst_offset += _key_group_size + sizeof(float) + sizeof(float)) {
auto p_k = reinterpret_cast<float*>(k_dst.ptr<typename ov::element_type_traits<KEY_DST_PREC>::value_type>(block_number, h, block_offset, dst_offset));
for (size_t src_offset = 0, dst_offset = 0; src_offset < S;
src_offset += _key_group_size, dst_offset += _key_group_size + sizeof(float) + sizeof(float)) {
auto p_k = reinterpret_cast<float*>(
k_dst.ptr<typename ov::element_type_traits<KEY_DST_PREC>::value_type>(block_number,
h,
block_offset,
dst_offset));
quant_u8(k_src.ptr<T>(b, h, m, src_offset),
k_dst.ptr<typename ov::element_type_traits<KEY_DST_PREC>::value_type>(block_number, h, block_offset, dst_offset) + sizeof(float) + sizeof(float),
_key_group_size,
p_k[0],
p_k[1]);
k_dst.ptr<typename ov::element_type_traits<KEY_DST_PREC>::value_type>(block_number,
h,
block_offset,
dst_offset) +
sizeof(float) + sizeof(float),
_key_group_size,
p_k[0],
p_k[1]);
}

for (size_t src_offset = 0, dst_offset = 0; src_offset < SV; src_offset += _value_group_size,
Expand All @@ -480,7 +512,10 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
});
}

template <typename T, ov::element::Type_t KEY_DST_PREC, ov::element::Type_t VALUE_DST_PREC, typename std::enable_if<VALUE_DST_PREC == ov::element::i4, bool>::type = true>
template <typename T,
ov::element::Type_t KEY_DST_PREC,
ov::element::Type_t VALUE_DST_PREC,
typename std::enable_if<VALUE_DST_PREC == ov::element::i4, bool>::type = true>
static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
const ov::intel_cpu::PlainTensor& v_src,
const ov::intel_cpu::PlainTensor& k_dst,
Expand All @@ -500,13 +535,22 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
auto block_offset = slot % block_size;
// The layout for per token per head:
// |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)|
for (size_t src_offset = 0, dst_offset = 0; src_offset < S; src_offset += _key_group_size, dst_offset += _key_group_size + sizeof(float) + sizeof(float)) {
auto p_k = reinterpret_cast<float*>(k_dst.ptr<typename ov::element_type_traits<KEY_DST_PREC>::value_type>(block_number, h, block_offset, dst_offset));
for (size_t src_offset = 0, dst_offset = 0; src_offset < S;
src_offset += _key_group_size, dst_offset += _key_group_size + sizeof(float) + sizeof(float)) {
auto p_k = reinterpret_cast<float*>(
k_dst.ptr<typename ov::element_type_traits<KEY_DST_PREC>::value_type>(block_number,
h,
block_offset,
dst_offset));
quant_u8(k_src.ptr<T>(b, h, m, src_offset),
k_dst.ptr<typename ov::element_type_traits<KEY_DST_PREC>::value_type>(block_number, h, block_offset, dst_offset) + sizeof(float) + sizeof(float),
_key_group_size,
p_k[0],
p_k[1]);
k_dst.ptr<typename ov::element_type_traits<KEY_DST_PREC>::value_type>(block_number,
h,
block_offset,
dst_offset) +
sizeof(float) + sizeof(float),
_key_group_size,
p_k[0],
p_k[1]);
}

for (size_t src_offset = 0, dst_offset = 0; src_offset < SV; src_offset += _value_group_size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ void attn_dequant_u4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale
// (q - zp) * scale
v_f32_low_half = _mm512_mul_ps(v_f32_low_half, v_scale);
v_f32_high_half = _mm512_mul_ps(v_f32_high_half, v_scale);

__m512i idx1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0);
__m512i idx2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8);
__m512 first_half = _mm512_permutex2var_ps(v_f32_low_half, idx1, v_f32_high_half);
Expand All @@ -106,7 +105,7 @@ void attn_dequant_u4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale
// q - zp
v_f32_low_half = _mm256_sub_ps(v_f32_low_half, v256_zp);
v_f32_high_half = _mm256_sub_ps(v_f32_high_half, v256_zp);

v_f32_low_half = _mm256_mul_ps(v_f32_low_half, v256_scale);
v_f32_high_half = _mm256_mul_ps(v_f32_high_half, v256_scale);

Expand Down Expand Up @@ -206,7 +205,7 @@ void attn_dequant_s4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale
#endif
auto extract_half_byte = [&](uint8_t val, bool high_half) -> int8_t {
uint8_t shift = high_half ? 0 : 4;
return float((val >> shift) & 0x000F);
return static_cast<float>((val >> shift) & 0x000F);
};
for (; i < n; ++i) {
float tmp = extract_half_byte(src_nc[i / 2], (uint8_t)(i % 2));
Expand Down
Loading

0 comments on commit 685f263

Please sign in to comment.