From f1bf96abe8816a86ec75cf69c63668ccc9238f8d Mon Sep 17 00:00:00 2001 From: Sanket Kale Date: Wed, 20 Nov 2024 11:12:22 +0530 Subject: [PATCH] Changed flag name and modified compile flag declaration Signed-off-by: Sanket Kale --- cmake/cpu_extension.cmake | 21 ++++----------------- csrc/cpu/attention.cpp | 2 +- csrc/cpu/cpu_types_arm.hpp | 23 +++++++++-------------- 3 files changed, 14 insertions(+), 32 deletions(-) diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 248c04d4b3dd2..68f7ca1af05ad 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -16,20 +16,8 @@ include_directories("${CMAKE_SOURCE_DIR}/csrc") # # Check the compile flags # -<<<<<<< HEAD -if (CMAKE_SYSTEM_PROCESSOR STREQUAL "ppc64le") - list(APPEND CXX_COMPILE_FLAGS - "-fopenmp" - "-DVLLM_CPU_EXTENSION") -else() - list(APPEND CXX_COMPILE_FLAGS - "-fopenmp" - "-mf16c" - "-DVLLM_CPU_EXTENSION") -endif() -======= -if (NOT CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64") +if (CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64") list(APPEND CXX_COMPILE_FLAGS "-mf16c" ) @@ -37,7 +25,6 @@ endif() list(APPEND CXX_COMPILE_FLAGS "-fopenmp" "-DVLLM_CPU_EXTENSION") ->>>>>>> eca86e66 (Rebased and resolved merge conflicts) execute_process(COMMAND cat /proc/cpuinfo RESULT_VARIABLE CPUINFO_RET @@ -72,7 +59,7 @@ find_isa(${CPUINFO} "avx512f" AVX512_FOUND) find_isa(${CPUINFO} "POWER10" POWER10_FOUND) find_isa(${CPUINFO} "POWER9" POWER9_FOUND) find_isa(${CPUINFO} "asimd" ASIMD_FOUND) # Check for ARM NEON support -find_isa(${CPUINFO} "bf16" BF16_FOUND) # Check for BF16 support +find_isa(${CPUINFO} "bf16" ARM_BF16_FOUND) # Check for ARM BF16 support if (AVX512_FOUND AND NOT AVX512_DISABLED) list(APPEND CXX_COMPILE_FLAGS @@ -107,10 +94,10 @@ elseif (POWER9_FOUND OR POWER10_FOUND) elseif (ASIMD_FOUND) message(STATUS "ARMv8 or later architecture detected") - if(BF16_FOUND) + if(ARM_BF16_FOUND) message(STATUS "BF16 extension detected") set(MARCH_FLAGS "-march=armv8.2-a+bf16+dotprod+fp16") - add_compile_definitions(BF16_SUPPORT) + add_compile_definitions(ARM_BF16_SUPPORT) else() message(WARNING "BF16 functionality is not available") set(MARCH_FLAGS "-march=armv8.2-a+dotprod+fp16") diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index a2ce4c21b6a50..e21832ba7582f 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -52,7 +52,7 @@ struct KernelVecType { }; #else #ifdef __aarch64__ - #ifndef BF16_SUPPORT + #ifndef ARM_BF16_SUPPORT // pass #else template <> diff --git a/csrc/cpu/cpu_types_arm.hpp b/csrc/cpu/cpu_types_arm.hpp index 1d29040690702..73e0f8cb2e0fb 100644 --- a/csrc/cpu/cpu_types_arm.hpp +++ b/csrc/cpu/cpu_types_arm.hpp @@ -1,11 +1,10 @@ #include #include - #include namespace vec_op { -#ifdef BF16_SUPPORT +#ifdef ARM_BF16_SUPPORT #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ @@ -103,7 +102,7 @@ struct FP16Vec16 : public Vec { }; -#ifdef BF16_SUPPORT +#ifdef ARM_BF16_SUPPORT struct BF16Vec8 : public Vec { constexpr static int VEC_ELEM_NUM = 8; @@ -209,7 +208,7 @@ struct FP32Vec8 : public Vec { explicit FP32Vec8(float16x8_t v) : reg({vcvt_f32_f16(vget_low_f16(v)), vcvt_f32_f16(vget_high_f16(v))}) {}; - #ifdef BF16_SUPPORT + #ifdef ARM_BF16_SUPPORT explicit FP32Vec8(bfloat16x8_t v) : reg({vcvtq_low_f32_bf16(v), vcvtq_high_f32_bf16(v)}) {}; @@ -333,7 +332,7 @@ struct FP32Vec16 : public Vec { explicit FP32Vec16(const FP16Vec8 &v) : FP32Vec16(FP32Vec8(v.reg)) {} - #ifdef BF16_SUPPORT + #ifdef ARM_BF16_SUPPORT explicit FP32Vec16(bfloat16x8x2_t v) : reg({ vcvtq_low_f32_bf16(v.val[0]), vcvtq_high_f32_bf16(v.val[0]), @@ -349,7 +348,7 @@ struct FP32Vec16 : public Vec { reg.val[3] = data.reg; }; - #ifdef BF16_SUPPORT + #ifdef ARM_BF16_SUPPORT explicit FP32Vec16(const BF16Vec16 &v) : reg({ vcvtq_low_f32_bf16(v.reg.val[0]), vcvtq_high_f32_bf16(v.reg.val[0]), @@ -367,10 +366,6 @@ struct FP32Vec16 : public Vec { reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1])); }; - // #ifdef BF16_SUPPORT - // explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} - // #endif - FP32Vec16 operator+(const FP32Vec16 &b) const { return FP32Vec16(float32x4x4_t({ vaddq_f32(reg.val[0], b.reg.val[0]), @@ -443,7 +438,7 @@ template <> struct VecType { using vec_type = FP32Vec8; }; template <> struct VecType { using vec_type = FP16Vec8; }; -#ifdef BF16_SUPPORT +#ifdef ARM_BF16_SUPPORT template <> struct VecType { using vec_type = BF16Vec8; }; #endif @@ -478,7 +473,7 @@ inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) { acc.reg.val[3] = vfmaq_f32(acc.reg.val[3], a.reg.val[3], b.reg.val[3]); }; -#ifdef BF16_SUPPORT +#ifdef ARM_BF16_SUPPORT inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) { float32x4_t a0_low = vcvt_f32_bf16(vget_low_bf16(a.reg.val[0])); @@ -498,7 +493,7 @@ inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) { }; #endif -#ifdef BF16_SUPPORT +#ifdef ARM_BF16_SUPPORT inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) : reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1])) {}; inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) : reg({ @@ -511,7 +506,7 @@ inline void prefetch(const void *addr) { __builtin_prefetch(addr, 0, 1); }; -#ifdef BF16_SUPPORT +#ifdef ARM_BF16_SUPPORT template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { *reinterpret_cast<__bf16 *>(ptr) = vcvth_bf16_f32(v);