Changed flag name and modified compile flag declaration
Signed-off-by: Sanket Kale <[email protected]>
Sanket Kale committed Nov 20, 2024
1 parent dea09ba commit f1bf96a
Showing 3 changed files with 14 additions and 32 deletions.
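In short, the detection that cpu_extension.cmake performs against /proc/cpuinfo is now surfaced to the C++ sources under a single ARM-scoped name: the CMake side calls add_compile_definitions(ARM_BF16_SUPPORT) when the "bf16" feature is found, and attention.cpp and cpu_types_arm.hpp guard their bfloat16 paths on that macro. A minimal consumer-side sketch of the pattern follows; the helper name widen_to_f32 is illustrative only and does not appear in the diff.

```cpp
// Sketch only: how code compiled by cpu_extension.cmake is expected to select
// a BF16 path. ARM_BF16_SUPPORT is the macro this commit defines via
// add_compile_definitions() when /proc/cpuinfo reports the "bf16" feature;
// widen_to_f32 is a hypothetical helper, not part of the repository.
#include <arm_neon.h>

#ifdef ARM_BF16_SUPPORT
// BF16-capable CPU: widen bfloat16 lanes to float32 before accumulating.
inline float32x4_t widen_to_f32(bfloat16x4_t v) { return vcvt_f32_bf16(v); }
#else
// No BF16 extension: callers stay on the FP16/FP32 code paths.
#endif
```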
21 changes: 4 additions & 17 deletions cmake/cpu_extension.cmake
@@ -16,28 +16,15 @@ include_directories("${CMAKE_SOURCE_DIR}/csrc")
#
# Check the compile flags
#
<<<<<<< HEAD
if (CMAKE_SYSTEM_PROCESSOR STREQUAL "ppc64le")
list(APPEND CXX_COMPILE_FLAGS
"-fopenmp"
"-DVLLM_CPU_EXTENSION")
else()
list(APPEND CXX_COMPILE_FLAGS
"-fopenmp"
"-mf16c"
"-DVLLM_CPU_EXTENSION")
endif()
=======

if (NOT CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
if (CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64")
list(APPEND CXX_COMPILE_FLAGS
"-mf16c"
)
endif()
list(APPEND CXX_COMPILE_FLAGS
"-fopenmp"
"-DVLLM_CPU_EXTENSION")
>>>>>>> eca86e66 (Rebased and resolved merge conflicts)

execute_process(COMMAND cat /proc/cpuinfo
RESULT_VARIABLE CPUINFO_RET
@@ -72,7 +59,7 @@ find_isa(${CPUINFO} "avx512f" AVX512_FOUND)
find_isa(${CPUINFO} "POWER10" POWER10_FOUND)
find_isa(${CPUINFO} "POWER9" POWER9_FOUND)
find_isa(${CPUINFO} "asimd" ASIMD_FOUND) # Check for ARM NEON support
find_isa(${CPUINFO} "bf16" BF16_FOUND) # Check for BF16 support
find_isa(${CPUINFO} "bf16" ARM_BF16_FOUND) # Check for ARM BF16 support

if (AVX512_FOUND AND NOT AVX512_DISABLED)
list(APPEND CXX_COMPILE_FLAGS
@@ -107,10 +94,10 @@ elseif (POWER9_FOUND OR POWER10_FOUND)

elseif (ASIMD_FOUND)
message(STATUS "ARMv8 or later architecture detected")
if(BF16_FOUND)
if(ARM_BF16_FOUND)
message(STATUS "BF16 extension detected")
set(MARCH_FLAGS "-march=armv8.2-a+bf16+dotprod+fp16")
add_compile_definitions(BF16_SUPPORT)
add_compile_definitions(ARM_BF16_SUPPORT)
else()
message(WARNING "BF16 functionality is not available")
set(MARCH_FLAGS "-march=armv8.2-a+dotprod+fp16")
2 changes: 1 addition & 1 deletion csrc/cpu/attention.cpp
@@ -52,7 +52,7 @@ struct KernelVecType<c10::BFloat16> {
};
#else
#ifdef __aarch64__
#ifndef BF16_SUPPORT
#ifndef ARM_BF16_SUPPORT
// pass
#else
template <>
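For readers skimming the attention.cpp hunk above: on aarch64, the c10::BFloat16 specialization of KernelVecType is only emitted when the build defined ARM_BF16_SUPPORT; otherwise the "// pass" branch leaves the BF16 kernels undefined. A self-contained sketch of that macro-gated specialization pattern (LaneCount and bf16_tag are placeholders, not the file's actual types):

```cpp
// Illustrative only: mirrors the #ifdef structure shown above, not the real
// KernelVecType members.
template <typename T>
struct LaneCount { static constexpr int value = 4; };  // generic fallback

#if defined(__aarch64__) && defined(ARM_BF16_SUPPORT)
struct bf16_tag {};  // stand-in for c10::BFloat16
template <>
struct LaneCount<bf16_tag> { static constexpr int value = 8; };  // BF16 path
#endif
```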
23 changes: 9 additions & 14 deletions csrc/cpu/cpu_types_arm.hpp
@@ -1,11 +1,10 @@
#include <arm_neon.h>
#include <torch/all.h>

#include <cmath>

namespace vec_op {

#ifdef BF16_SUPPORT
#ifdef ARM_BF16_SUPPORT
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
@@ -103,7 +102,7 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
};


#ifdef BF16_SUPPORT
#ifdef ARM_BF16_SUPPORT
struct BF16Vec8 : public Vec<BF16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;

@@ -209,7 +208,7 @@ struct FP32Vec8 : public Vec<FP32Vec8> {

explicit FP32Vec8(float16x8_t v) : reg({vcvt_f32_f16(vget_low_f16(v)), vcvt_f32_f16(vget_high_f16(v))}) {};

#ifdef BF16_SUPPORT
#ifdef ARM_BF16_SUPPORT

explicit FP32Vec8(bfloat16x8_t v) : reg({vcvtq_low_f32_bf16(v), vcvtq_high_f32_bf16(v)}) {};

@@ -333,7 +332,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {

explicit FP32Vec16(const FP16Vec8 &v) : FP32Vec16(FP32Vec8(v.reg)) {}

#ifdef BF16_SUPPORT
#ifdef ARM_BF16_SUPPORT
explicit FP32Vec16(bfloat16x8x2_t v) : reg({
vcvtq_low_f32_bf16(v.val[0]),
vcvtq_high_f32_bf16(v.val[0]),
@@ -349,7 +348,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
reg.val[3] = data.reg;
};

#ifdef BF16_SUPPORT
#ifdef ARM_BF16_SUPPORT
explicit FP32Vec16(const BF16Vec16 &v) : reg({
vcvtq_low_f32_bf16(v.reg.val[0]),
vcvtq_high_f32_bf16(v.reg.val[0]),
@@ -367,10 +366,6 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1]));
};

// #ifdef BF16_SUPPORT
// explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
// #endif

FP32Vec16 operator+(const FP32Vec16 &b) const {
return FP32Vec16(float32x4x4_t({
vaddq_f32(reg.val[0], b.reg.val[0]),
@@ -443,7 +438,7 @@ template <> struct VecType<float> { using vec_type = FP32Vec8; };

template <> struct VecType<c10::Half> { using vec_type = FP16Vec8; };

#ifdef BF16_SUPPORT
#ifdef ARM_BF16_SUPPORT
template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };
#endif

@@ -478,7 +473,7 @@ inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
acc.reg.val[3] = vfmaq_f32(acc.reg.val[3], a.reg.val[3], b.reg.val[3]);
};

#ifdef BF16_SUPPORT
#ifdef ARM_BF16_SUPPORT
inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) {

float32x4_t a0_low = vcvt_f32_bf16(vget_low_bf16(a.reg.val[0]));
@@ -498,7 +493,7 @@ inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) {
};
#endif

#ifdef BF16_SUPPORT
#ifdef ARM_BF16_SUPPORT
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) : reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1])) {};

inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) : reg({
@@ -511,7 +506,7 @@ inline void prefetch(const void *addr) {
__builtin_prefetch(addr, 0, 1);
};

#ifdef BF16_SUPPORT
#ifdef ARM_BF16_SUPPORT
template <>
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
*reinterpret_cast<__bf16 *>(ptr) = vcvth_bf16_f32(v);
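The final hunk renames the guard around storeFP32<c10::BFloat16>, which narrows a single float with vcvth_bf16_f32. For reference, here is a standalone round-trip of the scalar and vector conversions used in cpu_types_arm.hpp, assuming an aarch64 compiler invoked with a BF16-capable -march (as the cmake change above arranges when the CPU reports bf16):

```cpp
// Standalone sketch of the BF16 <-> FP32 conversions used in cpu_types_arm.hpp.
// Assumes aarch64 with e.g. -march=armv8.2-a+bf16+dotprod+fp16.
#include <arm_neon.h>
#include <cstdio>

int main() {
  float x = 1.5f;
  __bf16 b = vcvth_bf16_f32(x);            // narrow one float32 to bfloat16
  bfloat16x4_t v = vdup_n_bf16(b);         // broadcast into a 4-lane BF16 vector
  float32x4_t widened = vcvt_f32_bf16(v);  // widen the lanes back to float32
  printf("round-tripped: %f\n", vgetq_lane_f32(widened, 0));
  return 0;
}
```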
