From 73e43b5ace9326c3e4f8c60d70745506ec5c8d49 Mon Sep 17 00:00:00 2001 From: Tingqian Li Date: Fri, 14 Feb 2025 11:32:44 +0800 Subject: [PATCH] [CPU] fix qkv_proj/mlp jit kernel's win32 support (#28915) ### Details: - *MSVC on the win32 platform uses an x64 calling convention different from Linux; this change adds a fix to the jit kernel to support win32 systems* - ### Tickets: - *CVS-151107* --- .../src/nodes/kernels/x64/mlp_kernel.cpp | 54 +++++++++++++++---- .../src/nodes/kernels/x64/mlp_kernel.hpp | 31 ++++++++--- src/plugins/intel_cpu/src/nodes/qkv_proj.cpp | 7 +-- 3 files changed, 68 insertions(+), 24 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.cpp b/src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.cpp index 1d1b9e138b5232..97741c41913998 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.cpp @@ -21,8 +21,15 @@ void MKernel::generate_2x2() { Xbyak::Reg64 reg_A_addr = abi_param2; Xbyak::Reg64 reg_A_stride = abi_param3; Xbyak::Reg64 reg_B_addr = abi_param4; +#ifdef _WIN32 + Xbyak::Reg64 reg_C_addr = rdi; + Xbyak::Reg64 reg_C_stride = rsi; + push(rdi); + push(rsi); +#else Xbyak::Reg64 reg_C_addr = abi_param5; Xbyak::Reg64 reg_C_stride = abi_param6; +#endif Xbyak::Reg64 reg_ktiles = rax; Xbyak::Reg64 reg_B_stride = r10; @@ -139,6 +146,10 @@ void MKernel::generate_2x2() { tilestored(ptr[reg_C_addr + reg_C_stride + 64], tmmC11); pop(reg_prefetch); +#ifdef _WIN32 + pop(rsi); + pop(rdi); +#endif ret(); } @@ -173,8 +184,15 @@ void MKernel::generate_1x2() { Xbyak::Reg64 reg_A_addr = abi_param2; Xbyak::Reg64 reg_A_stride = abi_param3; Xbyak::Reg64 reg_B_addr = abi_param4; +#ifdef _WIN32 + Xbyak::Reg64 reg_C_addr = rdi; + Xbyak::Reg64 reg_C_stride = rsi; + push(rdi); + push(rsi); +#else Xbyak::Reg64 reg_C_addr = abi_param5; Xbyak::Reg64 reg_C_stride = abi_param6; +#endif Xbyak::Reg64 reg_ktiles = rax; Xbyak::Reg64 reg_B_stride = r10; @@ -253,7 +271,10 @@ void
MKernel::generate_1x2() { tilestored(ptr[reg_C_addr + reg_C_stride + 64], tmmC01); } L(skip_store); - +#ifdef _WIN32 + pop(rsi); + pop(rdi); +#endif ret(); } @@ -599,11 +620,18 @@ void GateUpCombine::generate() { void ReduceAdd2bh::generate() { if (m_do_reduce2) { - Xbyak::Reg64 src0 = abi_param1; - Xbyak::Reg64 src1 = abi_param2; - Xbyak::Reg64 dst = abi_param3; - Xbyak::Reg64 prefetch_dst = abi_param4; - Xbyak::Reg64 BN = abi_param5; + Xbyak::Reg64 src0 = rdx; + Xbyak::Reg64 src1 = r8; + Xbyak::Reg64 dst = r9; + Xbyak::Reg64 prefetch_dst = r10; + Xbyak::Reg64 BN = r11; + + mov(src0, ptr[abi_param1 + offsetof(CallArgs, src0)]); + mov(src1, ptr[abi_param1 + offsetof(CallArgs, src1)]); + mov(dst, ptr[abi_param1 + offsetof(CallArgs, dst)]); + mov(prefetch_dst, ptr[abi_param1 + offsetof(CallArgs, prefetch_dst)]); + mov(BN, ptr[abi_param1 + offsetof(CallArgs, num_cols)]); + Xbyak::Reg64 loop_i = rax; Xbyak::Label loop_begin; @@ -635,10 +663,16 @@ void ReduceAdd2bh::generate() { ret(); } else { - Xbyak::Reg64 src0 = abi_param1; - Xbyak::Reg64 dst = abi_param2; - Xbyak::Reg64 prefetch_dst = abi_param3; - Xbyak::Reg64 BN = abi_param4; + Xbyak::Reg64 src0 = rdx; + Xbyak::Reg64 dst = r9; + Xbyak::Reg64 prefetch_dst = r10; + Xbyak::Reg64 BN = r11; + + mov(src0, ptr[abi_param1 + offsetof(CallArgs, src0)]); + mov(dst, ptr[abi_param1 + offsetof(CallArgs, dst)]); + mov(prefetch_dst, ptr[abi_param1 + offsetof(CallArgs, prefetch_dst)]); + mov(BN, ptr[abi_param1 + offsetof(CallArgs, num_cols)]); + Xbyak::Reg64 loop_i = rax; Xbyak::Label loop_begin; diff --git a/src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.hpp index 1fcd1f000b3d92..438d84c16b3ece 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.hpp @@ -526,26 +526,41 @@ class ReduceAdd2bh : public dnnl::impl::cpu::x64::jit_generator { void generate() override; + struct CallArgs { 
+ float* src0; + float* src1; + int16_t * dst; + int16_t * prefetch_dst; + int64_t num_cols; + }; // add two float input eltwise and convert to bf16 : ConvertFP32toBF16(src0 + src1) void call(float* src0, float* src1, size_t src_stride, void* pf16_dst, size_t dst_stride, int num_rows, int num_cols) { - auto* dst = reinterpret_cast<int16_t*>(pf16_dst); - for (int m = 0; m < num_rows; m++, src0 += src_stride, src1 += src_stride, dst += dst_stride) { + CallArgs args; + args.src0 = src0; + args.src1 = src1; + args.dst = reinterpret_cast<int16_t*>(pf16_dst); + args.num_cols = num_cols; + for (int m = 0; m < num_rows; m++, args.src0 += src_stride, args.src1 += src_stride, args.dst += dst_stride) { // the prefetch distance is increased to ensure by the time store happens // prefetch has done and no HW prefetcher is triggered - auto* prefetch_dst = (m + 2 < num_rows) ? (dst + 2 * dst_stride) : (dst); - (*this)(src0, src1, dst, prefetch_dst, num_cols); + args.prefetch_dst = (m + 2 < num_rows) ? (args.dst + 2 * dst_stride) : (args.dst); + + (*this)(&args); } } // convert tensor to bf16: ConvertFP32toBF16(src0) void call(float* src0, size_t src_stride, void* pf16_dst, size_t dst_stride, int num_rows, int num_cols) { - auto* dst = reinterpret_cast<int16_t*>(pf16_dst); - for (int m = 0; m < num_rows; m++, src0 += src_stride, dst += dst_stride) { + CallArgs args; + args.src0 = src0; + args.dst = reinterpret_cast<int16_t*>(pf16_dst); + args.num_cols = num_cols; + for (int m = 0; m < num_rows; m++, args.src0 += src_stride, args.dst += dst_stride) { // the prefetch distance is increased to ensure by the time store happens // prefetch has done and no HW prefetcher is triggered - auto* prefetch_dst = (m + 2 < num_rows) ? (dst + 2 * dst_stride) : (dst); - (*this)(src0, dst, prefetch_dst, num_cols); + args.prefetch_dst = (m + 2 < num_rows) ?
(args.dst + 2 * dst_stride) : (args.dst); + (*this)(&args); } } }; diff --git a/src/plugins/intel_cpu/src/nodes/qkv_proj.cpp b/src/plugins/intel_cpu/src/nodes/qkv_proj.cpp index ed7bb00fd575db..163608999fada6 100644 --- a/src/plugins/intel_cpu/src/nodes/qkv_proj.cpp +++ b/src/plugins/intel_cpu/src/nodes/qkv_proj.cpp @@ -291,12 +291,7 @@ struct QKVProjection::Executor : public QKVProjection::ExecutorBase { asym); } // compress accumulation result into target - for (int mi = 0; mi < BM; mi++, src += stride_src, dst += stride_dst) { - // the prefetch distance is increased to ensure by the time store happens - // prefetch has done and no HW prefetcher is triggered - auto* prefetch_dst = (mi + 2 < BM) ? (dst + 2 * stride_dst) : (dst); - jit_cvt(src, dst, prefetch_dst, work.BN); - } + jit_cvt.call(src, stride_src, dst, stride_dst, BM, work.BN); } }); m += BM;