Skip to content

Commit

Permalink
[CPU] fix qkv_proj/mlp jit kernel's win32 support (#28915)
Browse files Browse the repository at this point in the history
### Details:
- *MSVC on the win32 platform uses an x64 calling convention different from
Linux; this change fixes the JIT kernels so that they support win32 systems.*
 - 
### Tickets:
 - *CVS-151107*
  • Loading branch information
usstq authored Feb 14, 2025
1 parent e550a08 commit 73e43b5
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 24 deletions.
54 changes: 44 additions & 10 deletions src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,15 @@ void MKernel::generate_2x2() {
Xbyak::Reg64 reg_A_addr = abi_param2;
Xbyak::Reg64 reg_A_stride = abi_param3;
Xbyak::Reg64 reg_B_addr = abi_param4;
#ifdef _WIN32
Xbyak::Reg64 reg_C_addr = rdi;
Xbyak::Reg64 reg_C_stride = rsi;
push(rdi);
push(rsi);
#else
Xbyak::Reg64 reg_C_addr = abi_param5;
Xbyak::Reg64 reg_C_stride = abi_param6;
#endif

Xbyak::Reg64 reg_ktiles = rax;
Xbyak::Reg64 reg_B_stride = r10;
Expand Down Expand Up @@ -139,6 +146,10 @@ void MKernel::generate_2x2() {
tilestored(ptr[reg_C_addr + reg_C_stride + 64], tmmC11);

pop(reg_prefetch);
#ifdef _WIN32
pop(rsi);
pop(rdi);
#endif
ret();
}

Expand Down Expand Up @@ -173,8 +184,15 @@ void MKernel::generate_1x2() {
Xbyak::Reg64 reg_A_addr = abi_param2;
Xbyak::Reg64 reg_A_stride = abi_param3;
Xbyak::Reg64 reg_B_addr = abi_param4;
#ifdef _WIN32
Xbyak::Reg64 reg_C_addr = rdi;
Xbyak::Reg64 reg_C_stride = rsi;
push(rdi);
push(rsi);
#else
Xbyak::Reg64 reg_C_addr = abi_param5;
Xbyak::Reg64 reg_C_stride = abi_param6;
#endif

Xbyak::Reg64 reg_ktiles = rax;
Xbyak::Reg64 reg_B_stride = r10;
Expand Down Expand Up @@ -253,7 +271,10 @@ void MKernel::generate_1x2() {
tilestored(ptr[reg_C_addr + reg_C_stride + 64], tmmC01);
}
L(skip_store);

#ifdef _WIN32
pop(rsi);
pop(rdi);
#endif
ret();
}

Expand Down Expand Up @@ -599,11 +620,18 @@ void GateUpCombine::generate() {

void ReduceAdd2bh::generate() {
if (m_do_reduce2) {
Xbyak::Reg64 src0 = abi_param1;
Xbyak::Reg64 src1 = abi_param2;
Xbyak::Reg64 dst = abi_param3;
Xbyak::Reg64 prefetch_dst = abi_param4;
Xbyak::Reg64 BN = abi_param5;
Xbyak::Reg64 src0 = rdx;
Xbyak::Reg64 src1 = r8;
Xbyak::Reg64 dst = r9;
Xbyak::Reg64 prefetch_dst = r10;
Xbyak::Reg64 BN = r11;

mov(src0, ptr[abi_param1 + offsetof(CallArgs, src0)]);
mov(src1, ptr[abi_param1 + offsetof(CallArgs, src1)]);
mov(dst, ptr[abi_param1 + offsetof(CallArgs, dst)]);
mov(prefetch_dst, ptr[abi_param1 + offsetof(CallArgs, prefetch_dst)]);
mov(BN, ptr[abi_param1 + offsetof(CallArgs, num_cols)]);

Xbyak::Reg64 loop_i = rax;

Xbyak::Label loop_begin;
Expand Down Expand Up @@ -635,10 +663,16 @@ void ReduceAdd2bh::generate() {

ret();
} else {
Xbyak::Reg64 src0 = abi_param1;
Xbyak::Reg64 dst = abi_param2;
Xbyak::Reg64 prefetch_dst = abi_param3;
Xbyak::Reg64 BN = abi_param4;
Xbyak::Reg64 src0 = rdx;
Xbyak::Reg64 dst = r9;
Xbyak::Reg64 prefetch_dst = r10;
Xbyak::Reg64 BN = r11;

mov(src0, ptr[abi_param1 + offsetof(CallArgs, src0)]);
mov(dst, ptr[abi_param1 + offsetof(CallArgs, dst)]);
mov(prefetch_dst, ptr[abi_param1 + offsetof(CallArgs, prefetch_dst)]);
mov(BN, ptr[abi_param1 + offsetof(CallArgs, num_cols)]);

Xbyak::Reg64 loop_i = rax;

Xbyak::Label loop_begin;
Expand Down
31 changes: 23 additions & 8 deletions src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -526,26 +526,41 @@ class ReduceAdd2bh : public dnnl::impl::cpu::x64::jit_generator {

void generate() override;

// Argument pack handed to the generated kernel as a single pointer
// (`(*this)(&args)`); generate() reads each field from
// `abi_param1 + offsetof(CallArgs, <field>)`, so field types and order
// must stay in sync with the loads emitted there.
// NOTE(review): presumably introduced so the kernel needs only the first
// ABI parameter register on both Linux and win32 — confirm against the commit intent.
struct CallArgs {
// First fp32 input row; call() advances it by src_stride for every row.
float* src0;
// Second fp32 input row; only assigned by the two-input call() overload and
// only loaded by the m_do_reduce2 variant of the kernel.
float* src1;
// Destination row, obtained by reinterpreting the caller's output pointer as
// int16_t (per the surrounding comments, the 16-bit elements hold bf16 values).
int16_t * dst;
// Row to prefetch: two rows ahead of dst while such a row exists, otherwise dst itself.
int16_t * prefetch_dst;
// Number of columns to process per row; the kernel loads it into a 64-bit register (BN).
int64_t num_cols;
};
// add two float inputs element-wise and convert the sum to bf16: ConvertFP32toBF16(src0 + src1)
void
call(float* src0, float* src1, size_t src_stride, void* pf16_dst, size_t dst_stride, int num_rows, int num_cols) {
auto* dst = reinterpret_cast<int16_t*>(pf16_dst);
for (int m = 0; m < num_rows; m++, src0 += src_stride, src1 += src_stride, dst += dst_stride) {
CallArgs args;
args.src0 = src0;
args.src1 = src1;
args.dst = reinterpret_cast<int16_t*>(pf16_dst);
args.num_cols = num_cols;
for (int m = 0; m < num_rows; m++, args.src0 += src_stride, args.src1 += src_stride, args.dst += dst_stride) {
// the prefetch distance is increased to ensure that, by the time the store happens,
// the prefetch has completed and no HW prefetcher is triggered
auto* prefetch_dst = (m + 2 < num_rows) ? (dst + 2 * dst_stride) : (dst);
(*this)(src0, src1, dst, prefetch_dst, num_cols);
args.prefetch_dst = (m + 2 < num_rows) ? (args.dst + 2 * dst_stride) : (args.dst);

(*this)(&args);
}
}

// convert tensor to bf16: ConvertFP32toBF16(src0)
void call(float* src0, size_t src_stride, void* pf16_dst, size_t dst_stride, int num_rows, int num_cols) {
auto* dst = reinterpret_cast<int16_t*>(pf16_dst);
for (int m = 0; m < num_rows; m++, src0 += src_stride, dst += dst_stride) {
CallArgs args;
args.src0 = src0;
args.dst = reinterpret_cast<int16_t*>(pf16_dst);
args.num_cols = num_cols;
for (int m = 0; m < num_rows; m++, args.src0 += src_stride, args.dst += dst_stride) {
// the prefetch distance is increased to ensure that, by the time the store happens,
// the prefetch has completed and no HW prefetcher is triggered
auto* prefetch_dst = (m + 2 < num_rows) ? (dst + 2 * dst_stride) : (dst);
(*this)(src0, dst, prefetch_dst, num_cols);
args.prefetch_dst = (m + 2 < num_rows) ? (args.dst + 2 * dst_stride) : (args.dst);
(*this)(&args);
}
}
};
Expand Down
7 changes: 1 addition & 6 deletions src/plugins/intel_cpu/src/nodes/qkv_proj.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,12 +291,7 @@ struct QKVProjection::Executor : public QKVProjection::ExecutorBase {
asym);
}
// compress accumulation result into target
for (int mi = 0; mi < BM; mi++, src += stride_src, dst += stride_dst) {
// the prefetch distance is increased to ensure that, by the time the store happens,
// the prefetch has completed and no HW prefetcher is triggered
auto* prefetch_dst = (mi + 2 < BM) ? (dst + 2 * stride_dst) : (dst);
jit_cvt(src, dst, prefetch_dst, work.BN);
}
jit_cvt.call(src, stride_src, dst, stride_dst, BM, work.BN);
}
});
m += BM;
Expand Down

0 comments on commit 73e43b5

Please sign in to comment.