From 38a1674abbba38344543170cb552e88e7f619167 Mon Sep 17 00:00:00 2001 From: Chip Kerchner <49959681+ChipKerchner@users.noreply.github.com> Date: Wed, 26 Jun 2024 17:53:04 -0400 Subject: [PATCH] Support CPU inference with VSX PowerPC ISA (#5652) --- Dockerfile.ppc64le | 22 ++ cmake/cpu_extension.cmake | 11 +- csrc/cpu/cpu_types.hpp | 514 +----------------------------------- csrc/cpu/cpu_types_vsx.hpp | 491 +++++++++++++++++++++++++++++++++++ csrc/cpu/cpu_types_x86.hpp | 515 +++++++++++++++++++++++++++++++++++++ csrc/ops.h | 1 + requirements-cpu.txt | 6 +- 7 files changed, 1049 insertions(+), 511 deletions(-) create mode 100644 Dockerfile.ppc64le create mode 100644 csrc/cpu/cpu_types_vsx.hpp create mode 100644 csrc/cpu/cpu_types_x86.hpp diff --git a/Dockerfile.ppc64le b/Dockerfile.ppc64le new file mode 100644 index 0000000000000..d4e4c483cada8 --- /dev/null +++ b/Dockerfile.ppc64le @@ -0,0 +1,22 @@ +FROM mambaorg/micromamba +ARG MAMBA_DOCKERFILE_ACTIVATE=1 +USER root + +RUN apt-get update -y && apt-get install -y git wget vim numactl gcc-12 g++-12 protobuf-compiler libprotobuf-dev && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 + +# Some packages in requirements-cpu are installed here +# IBM provides optimized packages for ppc64le processors in the open-ce project for mamba +# Currently these may not be available for venv or pip directly +RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults python=3.10 pytorch-cpu=2.1.2 torchvision-cpu=0.16.2 && micromamba clean --all --yes + +COPY ./ /workspace/vllm + +WORKDIR /workspace/vllm + +# These packages will be in rocketce eventually +RUN pip install -v -r requirements-cpu.txt --prefer-binary --extra-index-url https://repo.fury.io/mgiessing + +RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install + +WORKDIR /vllm-workspace +ENTRYPOINT ["/opt/conda/bin/python3", "-m", "vllm.entrypoints.openai.api_server"] diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 511e443f78403..690559ee265e9 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -46,6 +46,8 @@ is_avx512_disabled(AVX512_DISABLED) find_isa(${CPUINFO} "avx2" AVX2_FOUND) find_isa(${CPUINFO} "avx512f" AVX512_FOUND) +find_isa(${CPUINFO} "POWER10" POWER10_FOUND) +find_isa(${CPUINFO} "POWER9" POWER9_FOUND) if (AVX512_FOUND AND NOT AVX512_DISABLED) list(APPEND CXX_COMPILE_FLAGS @@ -68,8 +70,15 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) elseif (AVX2_FOUND) list(APPEND CXX_COMPILE_FLAGS "-mavx2") message(WARNING "vLLM CPU backend using AVX2 ISA") +elseif (POWER9_FOUND OR POWER10_FOUND) + message(STATUS "PowerPC detected") + # Check for PowerPC VSX support + list(APPEND CXX_COMPILE_FLAGS + "-mvsx" + "-mcpu=native" + "-mtune=native") else() - message(FATAL_ERROR "vLLM CPU backend requires AVX512 or AVX2 ISA support.") + message(FATAL_ERROR "vLLM CPU backend requires AVX512 or AVX2 or Power9+ ISA support.") endif() message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") diff --git a/csrc/cpu/cpu_types.hpp b/csrc/cpu/cpu_types.hpp index d7621aaae81c9..0213be09105ed 100644 --- a/csrc/cpu/cpu_types.hpp +++ b/csrc/cpu/cpu_types.hpp @@ -2,514 +2,14 @@ #ifndef CPU_TYPES_HPP #define CPU_TYPES_HPP -#include -#include - -#ifndef __AVX2__ -static_assert(false, "AVX2 must be supported for the current implementation."); -#endif - -namespace vec_op { - -// FIXME: FP16 is not fully supported in Torch-CPU -#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) - -#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) - -#ifndef CPU_OP_GUARD -#define CPU_KERNEL_GUARD_IN(NAME) -#define CPU_KERNEL_GUARD_OUT(NAME) -#else -#define CPU_KERNEL_GUARD_IN(NAME) \ - std::cout << #NAME << " invoked." << std::endl; -#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl; -#endif - -#define FORCE_INLINE __attribute__((always_inline)) inline - -namespace { -template -constexpr void unroll_loop_item(std::integer_sequence, F &&f) { - (f(std::integral_constant{}), ...); -} -}; // namespace - -template >> -constexpr void unroll_loop(F &&f) { - unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); -} - -template struct Vec { - constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; } -}; - -struct FP32Vec8; -struct FP32Vec16; - -#ifdef __AVX512FP16__ -struct FP16Vec8 : public Vec { - constexpr static int VEC_ELEM_NUM = 8; - - __m128h reg; - - explicit FP16Vec8(_Float16 v) : reg(_mm_set1_ph(v)) {} - - explicit FP16Vec8(const void *ptr) : reg(_mm_loadu_ph(ptr)) {} - - explicit FP16Vec8(__m128h data) : reg(data) {} - - FP16Vec8 operator*(const FP16Vec8 &b) const { - return FP16Vec8(_mm_mul_ph(reg, b.reg)); - } - - FP16Vec8 operator+(const FP16Vec8 &b) const { - return FP16Vec8(_mm_add_ph(reg, b.reg)); - } - - FP16Vec8 operator-(const FP16Vec8 &b) const { - return FP16Vec8(_mm_sub_ph(reg, b.reg)); - } - - FP16Vec8 operator/(const FP16Vec8 &b) const { - return FP16Vec8(_mm_div_ph(reg, b.reg)); - } - - void save(void *ptr) const { _mm_storeu_ph(ptr, reg); } -}; -#endif - -struct BF16Vec8 : public Vec { - constexpr static int VEC_ELEM_NUM = 8; - - __m128i reg; - - explicit BF16Vec8(const void *ptr) - : reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {} - - explicit BF16Vec8(const FP32Vec8 &); - - void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; } -}; - -struct BF16Vec16 : public Vec { - constexpr static int VEC_ELEM_NUM = 16; - - __m256i reg; - - explicit BF16Vec16(const void *ptr) - : reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {} - - explicit BF16Vec16(const FP32Vec16 &); - - void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; } -}; - -#ifdef __AVX512F__ -struct BF16Vec32 : public Vec { - constexpr static int VEC_ELEM_NUM = 32; - - __m512i reg; - - explicit BF16Vec32(const void *ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {} - - explicit BF16Vec32(__m512i data) : reg(data) {} - - explicit BF16Vec32(BF16Vec8 &vec8_data) - : reg((__m512i)_mm512_inserti32x4( - _mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512( - (__m128i)vec8_data.reg), - (__m128i)vec8_data.reg, 1), - (__m128i)vec8_data.reg, 2), - (__m128i)vec8_data.reg, 3)) {} - - void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; } -}; -#else -struct BF16Vec32 : public Vec { - constexpr static int VEC_ELEM_NUM = 32; - - __m256i reg_low; - __m256i reg_high; - - explicit BF16Vec32(const void *ptr) - : reg_low(_mm256_loadu_si256((__m256i const *)ptr)), - reg_high(_mm256_loadu_si256((__m256i const *)ptr + 1)) {} - - explicit BF16Vec32(__m256i low, __m256i high) : reg_low(low), - reg_high(high) {} - - explicit BF16Vec32(BF16Vec8 &vec8_data) - : reg_low((__m256i)_mm256_inserti32x4( - _mm256_castsi128_si256((__m128i)vec8_data.reg), - (__m128i)vec8_data.reg, 1)), - reg_high((__m256i)_mm256_inserti32x4( - _mm256_castsi128_si256((__m128i)vec8_data.reg), - (__m128i)vec8_data.reg, 1)) {} - - void save(void *ptr) const { - *reinterpret_cast<__m256i *>(ptr) = reg_low; - *reinterpret_cast<__m256i *>((__m256i *)ptr + 1) = reg_high; - } -}; -#endif - -struct FP32Vec4 : public Vec { - constexpr static int VEC_ELEM_NUM = 4; - union AliasReg { - __m128 reg; - float values[VEC_ELEM_NUM]; - }; - - __m128 reg; - - explicit FP32Vec4(float v) : reg(_mm_set1_ps(v)) {} - - explicit FP32Vec4() : reg(_mm_set1_ps(0.0)) {} - - explicit FP32Vec4(const float *ptr) : reg(_mm_loadu_ps(ptr)) {} - - explicit FP32Vec4(__m128 data) : reg(data) {} - - explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {} -}; - -struct FP32Vec8 : public Vec { - constexpr static int VEC_ELEM_NUM = 8; - union AliasReg { - __m256 reg; - float values[VEC_ELEM_NUM]; - }; - - __m256 reg; - - explicit FP32Vec8(float v) : reg(_mm256_set1_ps(v)) {} - - explicit FP32Vec8() : reg(_mm256_set1_ps(0.0)) {} - - explicit FP32Vec8(const float *ptr) : reg(_mm256_loadu_ps(ptr)) {} - - explicit FP32Vec8(__m256 data) : reg(data) {} - - explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {} - -#ifdef __AVX512FP16__ - explicit FP32Vec8(__m128h v) : reg(_mm256_cvtph_ps(_mm_castph_si128(v))) {} -#endif - - explicit FP32Vec8(const BF16Vec8 &v) - : reg(_mm256_castsi256_ps( - _mm256_bslli_epi128(_mm256_cvtepu16_epi32(v.reg), 2))) {} - - float reduce_sum() const { - AliasReg ar; - ar.reg = reg; - float result = 0; - unroll_loop([&result, &ar](int i) { result += ar.values[i]; }); - - return result; - } - - FP32Vec8 exp() const { - AliasReg ar; - ar.reg = reg; - return FP32Vec8(_mm256_set_ps(expf(ar.values[7]), expf(ar.values[6]), - expf(ar.values[5]), expf(ar.values[4]), - expf(ar.values[3]), expf(ar.values[2]), - expf(ar.values[1]), expf(ar.values[0]))); - } - - FP32Vec8 tanh() const { - AliasReg ar; - ar.reg = reg; - return FP32Vec8(_mm256_set_ps(tanhf(ar.values[7]), tanhf(ar.values[6]), - tanhf(ar.values[5]), tanhf(ar.values[4]), - tanhf(ar.values[3]), tanhf(ar.values[2]), - tanhf(ar.values[1]), tanhf(ar.values[0]))); - } - - FP32Vec8 er() const { - AliasReg ar; - ar.reg = reg; - return FP32Vec8(_mm256_set_ps(erf(ar.values[7]), erf(ar.values[6]), - erf(ar.values[5]), erf(ar.values[4]), - erf(ar.values[3]), erf(ar.values[2]), - erf(ar.values[1]), erf(ar.values[0]))); - } - - FP32Vec8 operator*(const FP32Vec8 &b) const { - return FP32Vec8(_mm256_mul_ps(reg, b.reg)); - } - - FP32Vec8 operator+(const FP32Vec8 &b) const { - return FP32Vec8(_mm256_add_ps(reg, b.reg)); - } - - FP32Vec8 operator-(const FP32Vec8 &b) const { - return FP32Vec8(_mm256_sub_ps(reg, b.reg)); - } - - FP32Vec8 operator/(const FP32Vec8 &b) const { - return FP32Vec8(_mm256_div_ps(reg, b.reg)); - } - - void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); } -}; - -#ifdef __AVX512F__ -struct FP32Vec16 : public Vec { - constexpr static int VEC_ELEM_NUM = 16; - union AliasReg { - __m512 reg; - float values[VEC_ELEM_NUM]; - }; - - __m512 reg; - - explicit FP32Vec16(float v) : reg(_mm512_set1_ps(v)) {} - - explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {} - - explicit FP32Vec16(const float *ptr) : reg(_mm512_loadu_ps(ptr)) {} - - explicit FP32Vec16(__m512 data) : reg(data) {} - - explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {} - - explicit FP32Vec16(const FP32Vec4 &data) - : reg((__m512)_mm512_inserti32x4( - _mm512_inserti32x4( - _mm512_inserti32x4(_mm512_castsi128_si512((__m128i)data.reg), - (__m128i)data.reg, 1), - (__m128i)data.reg, 2), - (__m128i)data.reg, 3)) {} - - explicit FP32Vec16(const FP32Vec8 &data) - : reg((__m512)_mm512_inserti32x8( - _mm512_castsi256_si512((__m256i)data.reg), (__m256i)data.reg, 1)) {} - - explicit FP32Vec16(const BF16Vec16 &v) - : reg(_mm512_castsi512_ps( - _mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {} - - explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} - - FP32Vec16 operator*(const FP32Vec16 &b) const { - return FP32Vec16(_mm512_mul_ps(reg, b.reg)); - } - - FP32Vec16 operator+(const FP32Vec16 &b) const { - return FP32Vec16(_mm512_add_ps(reg, b.reg)); - } - - FP32Vec16 operator-(const FP32Vec16 &b) const { - return FP32Vec16(_mm512_sub_ps(reg, b.reg)); - } - - FP32Vec16 operator/(const FP32Vec16 &b) const { - return FP32Vec16(_mm512_div_ps(reg, b.reg)); - } - - float reduce_sum() const { return _mm512_reduce_add_ps(reg); } - - template float reduce_sub_sum(int idx) { - static_assert(VEC_ELEM_NUM % group_size == 0); - constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); - __mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size)); - return _mm512_mask_reduce_add_ps(mask, reg); - } - - void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); } -}; +#if defined(__x86_64__) + //x86 implementation + #include "cpu_types_x86.hpp" +#elif defined(__POWER9_VECTOR__) + //ppc implementation + #include "cpu_types_vsx.hpp" #else -struct FP32Vec16 : public Vec { - constexpr static int VEC_ELEM_NUM = 16; - - union AliasReg { - __m256 reg; - float values[8]; - }; - - __m256 reg_low; - __m256 reg_high; - - explicit FP32Vec16(float v) : reg_low(_mm256_set1_ps(v)), - reg_high(_mm256_set1_ps(v)) {} - - explicit FP32Vec16() : reg_low(_mm256_set1_ps(0.0)), - reg_high(_mm256_set1_ps(0.0)) {} - - explicit FP32Vec16(const float *ptr) : reg_low(_mm256_loadu_ps(ptr)), - reg_high(_mm256_loadu_ps(ptr + 8)) {} - - explicit FP32Vec16(__m256 low, __m256 high) : reg_low(low), reg_high(high) {} - - explicit FP32Vec16(const FP32Vec16 &data) : reg_low(data.reg_low), - reg_high(data.reg_high) {} - - explicit FP32Vec16(const FP32Vec4 &data) - : reg_low((__m256)_mm256_inserti128_si256( - _mm256_castsi128_si256((__m128i)data.reg), - (__m128i)data.reg, 1)), - reg_high((__m256)_mm256_inserti128_si256( - _mm256_castsi128_si256((__m128i)data.reg), - (__m128i)data.reg, 1)) {} - - explicit FP32Vec16(const FP32Vec8 &data) - : reg_low(data.reg), reg_high(data.reg) {} - - explicit FP32Vec16(const BF16Vec16 &v) { - __m128i low = _mm256_extractf128_si256(v.reg, 0); - __m128i high = _mm256_extractf128_si256(v.reg, 1); - - __m256i v_low_epi32 = _mm256_cvtepu16_epi32(low); - __m256i v_high_epi32 = _mm256_cvtepu16_epi32(high); - - __m256i v_low_shifted = _mm256_bslli_epi128(v_low_epi32, 2); - __m256i v_high_shifted = _mm256_bslli_epi128(v_high_epi32, 2); - - reg_low = _mm256_castsi256_ps(v_low_shifted); - reg_high = _mm256_castsi256_ps(v_high_shifted); - } - - explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} - - FP32Vec16 operator*(const FP32Vec16 &b) const { - return FP32Vec16(_mm256_mul_ps(reg_low, b.reg_low), - _mm256_mul_ps(reg_high, b.reg_high)); - } - - FP32Vec16 operator+(const FP32Vec16 &b) const { - return FP32Vec16(_mm256_add_ps(reg_low, b.reg_low), - _mm256_add_ps(reg_high, b.reg_high)); - } - - FP32Vec16 operator-(const FP32Vec16 &b) const { - return FP32Vec16(_mm256_sub_ps(reg_low, b.reg_low), - _mm256_sub_ps(reg_high, b.reg_high)); - } - - FP32Vec16 operator/(const FP32Vec16 &b) const { - return FP32Vec16(_mm256_div_ps(reg_low, b.reg_low), - _mm256_div_ps(reg_high, b.reg_high)); - } - - float reduce_sum() const { - FP32Vec8 low = FP32Vec8(reg_low); - FP32Vec8 high = FP32Vec8(reg_high); - return low.reduce_sum() + high.reduce_sum(); - } - - template float reduce_sub_sum(int idx) { - float sum = 0.0; - static_assert(VEC_ELEM_NUM % group_size == 0); - constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); - uint32_t mask = base_mask << (idx * group_size); - - AliasReg ar; - - auto func = [&sum, &mask, &ar](int i) { - int flag = mask & 0x1; - mask = mask >> 1; - if (flag != 0) sum += ar.values[i]; - }; - - ar.reg = reg_low; - unroll_loop(func); - - ar.reg = reg_high; - unroll_loop(func); - - return sum; - } - - void save(float *ptr) const { - _mm256_storeu_ps(ptr, reg_low); - _mm256_storeu_ps(ptr + 8, reg_high); - } -}; -#endif - -template struct VecType { using vec_type = void; }; - -template using vec_t = typename VecType::vec_type; - -template <> struct VecType { using vec_type = FP32Vec8; }; - -#ifdef __AVX512FP16__ -template <> struct VecType { using vec_type = FP16Vec16; }; + #warning "unsupported vLLM cpu implementation" #endif -template <> struct VecType { using vec_type = BF16Vec8; }; - -template void storeFP32(float v, T *ptr) { *ptr = v; } - -#ifdef __AVX512FP16__ -template <> inline void storeFP32(float v, c10::Half *ptr) { - *reinterpret_cast<_Float16 *>(ptr) = v; -} -#endif - -inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) { - acc = acc + a * b; -} - -#ifdef __AVX512BF16__ -template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { - *reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v); -} - -inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) - : reg((__m128i)_mm256_cvtneps_pbh(v.reg)) {} - -inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) - : reg((__m256i)_mm512_cvtneps_pbh(v.reg)) {} - -inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) { - acc.reg = _mm512_dpbf16_ps(acc.reg, (__m512bh)a.reg, (__m512bh)b.reg); -} -#else -template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { - c10::BFloat16 __attribute__((__may_alias__)) *v_ptr = - reinterpret_cast(&v); - *ptr = *(v_ptr + 1); -} - -#ifdef __AVX512F__ -inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) - : reg(_mm256_cvtepi32_epi16( - _mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {} - -inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) - : reg(_mm512_cvtepi32_epi16( - _mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {} -#else -namespace{ -__m128i FP32Vec8_to_BF16Vec8_avx2(__m256 a) { - __m256i ai = _mm256_castps_si256(a); - ai = _mm256_srli_epi32(ai, 16); - ai = _mm256_packus_epi32(ai, ai); - ai = _mm256_permute4x64_epi64(ai, 0b00111001); - return _mm256_extracti128_si256(ai, 0); -} -} - -inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) - : reg(FP32Vec8_to_BF16Vec8_avx2(v.reg)) {} - -inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) { - BF16Vec8 low = BF16Vec8(FP32Vec8(v.reg_low)); - BF16Vec8 high = BF16Vec8(FP32Vec8(v.reg_high)); - reg = _mm256_insertf128_si256(_mm256_castsi128_si256(low.reg), high.reg, 1); -} -#endif // __AVX512F__ -#endif // __AVX512BF16__ - -inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); } - -}; // namespace vec_op - #endif diff --git a/csrc/cpu/cpu_types_vsx.hpp b/csrc/cpu/cpu_types_vsx.hpp new file mode 100644 index 0000000000000..b50bdadc5713d --- /dev/null +++ b/csrc/cpu/cpu_types_vsx.hpp @@ -0,0 +1,491 @@ + +#ifndef CPU_TYPES_VSX_HPP +#define CPU_TYPES_VSX_HPP + +#include +#include +#include + +namespace vec_op { + +// FIXME: FP16 is not fully supported in Torch-CPU +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +#ifndef CPU_OP_GUARD +#define CPU_KERNEL_GUARD_IN(NAME) +#define CPU_KERNEL_GUARD_OUT(NAME) +#else +#define CPU_KERNEL_GUARD_IN(NAME) \ + std::cout << #NAME << " invoked." << std::endl; +#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl; +#endif + +#define FORCE_INLINE __attribute__((always_inline)) inline + +namespace { +template +constexpr void unroll_loop_item(std::integer_sequence, F &&f) { + (f(std::integral_constant{}), ...); +} +}; // namespace + +template >> +constexpr void unroll_loop(F &&f) { + unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); +} + +template struct Vec { + constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; } +}; + +typedef struct ss16x8x2_t { + __vector signed short val[2]; +} ss16x8x2_t; + +typedef struct ss16x8x4_t { + __vector signed short val[4]; +} ss16x8x4_t; + +typedef struct f32x4x2_t { + __vector float val[2]; +} f32x4x2_t; + +typedef struct f32x4x4_t { + __vector float val[4]; +} f32x4x4_t; + +struct FP32Vec8; +struct FP32Vec16; + +struct BF16Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + + __vector signed short reg; + + explicit BF16Vec8(const void *ptr) + : reg((__vector signed short)vec_xl(0, (__vector signed short *)ptr)) {} + + explicit BF16Vec8(const FP32Vec8 &); + + void save(void *ptr) const { *reinterpret_cast<__vector signed short *>(ptr) = reg; } +}; + +struct BF16Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + + ss16x8x2_t reg; + + explicit BF16Vec16(const void *ptr) { + // Load 256 bits in two parts + reg.val[0] = (__vector signed short)vec_xl(0, (signed short *)ptr); + reg.val[1] = (__vector signed short)vec_xl(16, (signed short *)ptr); + } + + explicit BF16Vec16(const FP32Vec16 &); + + void save(void *ptr) const { + // Save 256 bits in two parts + vec_xst(reg.val[0], 0, (signed short *)ptr); + vec_xst(reg.val[1], 16, (signed short *)ptr); + } +}; + +const static __vector signed short zero = vec_splats((signed short)0); + +struct BF16Vec32 : public Vec { + constexpr static int VEC_ELEM_NUM = 32; + + ss16x8x4_t reg; + explicit BF16Vec32(const void *ptr) + : reg(*reinterpret_cast(ptr)) {} + + explicit BF16Vec32(ss16x8x4_t data) : reg(data) {} + + explicit BF16Vec32(const BF16Vec8 &vec8_data) : reg({ + vec8_data.reg, + vec8_data.reg, + vec8_data.reg, + vec8_data.reg + }) {} + + void save(void *ptr) const { *reinterpret_cast(ptr) = reg; } +}; + +struct FP32Vec4 : public Vec { + constexpr static int VEC_ELEM_NUM = 4; + union AliasReg { + __vector float reg; + float values[VEC_ELEM_NUM]; + }; + + __vector float reg; + + explicit FP32Vec4(float v) : reg(vec_splats(v)) {} + + explicit FP32Vec4() : reg(vec_splats(0.0f)) {} + + explicit FP32Vec4(const float *ptr) : reg(vec_xl(0, ptr)) {} + + explicit FP32Vec4(__vector float data) : reg(data) {} + + explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {} +}; + +struct FP32Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + union AliasReg { + f32x4x2_t reg; + float values[VEC_ELEM_NUM]; + }; + + f32x4x2_t reg; + + explicit FP32Vec8(float v) { + reg.val[0] = vec_splats(v); + reg.val[1] = vec_splats(v); + } + + explicit FP32Vec8() { + reg.val[0] = vec_splats(0.0f); + reg.val[1] = vec_splats(0.0f); + } + + explicit FP32Vec8(const float *ptr) { + reg.val[0] = vec_xl(0, ptr); + reg.val[1] = vec_xl(16, ptr); + } + + explicit FP32Vec8(f32x4x2_t data) : reg(data) {} + + explicit FP32Vec8(const FP32Vec8 &data) { + reg.val[0] = data.reg.val[0]; + reg.val[1] = data.reg.val[1]; + } + + explicit FP32Vec8(const BF16Vec8 &v) { + reg.val[0] = (__vector float)vec_mergeh(zero, v.reg); + reg.val[1] = (__vector float)vec_mergel(zero, v.reg); + } + + float reduce_sum() const { + AliasReg ar; + ar.reg = reg; + float result = 0; + unroll_loop([&result, &ar](int i) { result += ar.values[i]; }); + + return result; + } + + FP32Vec8 exp() const { + // TODO: Vectorize this + AliasReg ar; + ar.reg = reg; + f32x4x4_t ret; + ret.val[0][0] = std::exp(ar.values[0]); + ret.val[0][1] = std::exp(ar.values[1]); + ret.val[0][2] = std::exp(ar.values[2]); + ret.val[0][3] = std::exp(ar.values[3]); + ret.val[1][0] = std::exp(ar.values[4]); + ret.val[1][1] = std::exp(ar.values[5]); + ret.val[1][2] = std::exp(ar.values[6]); + ret.val[1][3] = std::exp(ar.values[7]); + return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]})); + } + + FP32Vec8 tanh() const { + // TODO: Vectorize this + AliasReg ar; + ar.reg = reg; + f32x4x4_t ret; + ret.val[0][0] = std::tanh(ar.values[0]); + ret.val[0][1] = std::tanh(ar.values[1]); + ret.val[0][2] = std::tanh(ar.values[2]); + ret.val[0][3] = std::tanh(ar.values[3]); + ret.val[1][0] = std::tanh(ar.values[4]); + ret.val[1][1] = std::tanh(ar.values[5]); + ret.val[1][2] = std::tanh(ar.values[6]); + ret.val[1][3] = std::tanh(ar.values[7]); + return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]})); + } + + FP32Vec8 er() const { + // TODO: Vectorize this + AliasReg ar; + ar.reg = reg; + f32x4x4_t ret; + ret.val[0][0] = std::erf(ar.values[0]); + ret.val[0][1] = std::erf(ar.values[1]); + ret.val[0][2] = std::erf(ar.values[2]); + ret.val[0][3] = std::erf(ar.values[3]); + ret.val[1][0] = std::erf(ar.values[4]); + ret.val[1][1] = std::erf(ar.values[5]); + ret.val[1][2] = std::erf(ar.values[6]); + ret.val[1][3] = std::erf(ar.values[7]); + return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]})); + } + + FP32Vec8 operator*(const FP32Vec8 &b) const { + return FP32Vec8({vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])}); + } + + FP32Vec8 operator+(const FP32Vec8 &b) const { + return FP32Vec8({vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])}); + } + + FP32Vec8 operator-(const FP32Vec8 &b) const { + return FP32Vec8({vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])}); + } + + FP32Vec8 operator/(const FP32Vec8 &b) const { + return FP32Vec8({vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])}); + } + + void save(float *ptr) const { + vec_xst(reg.val[0], 0, ptr); + vec_xst(reg.val[1], 16, ptr); + } +}; + +struct FP32Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + union AliasReg { + f32x4x4_t reg; + float values[VEC_ELEM_NUM]; + }; + + f32x4x4_t reg; + + explicit FP32Vec16(float v) { + reg.val[0] = vec_splats(v); + reg.val[1] = vec_splats(v); + reg.val[2] = vec_splats(v); + reg.val[3] = vec_splats(v); + } + + explicit FP32Vec16() { + reg.val[0] = vec_splats(0.0f); + reg.val[1] = vec_splats(0.0f); + reg.val[2] = vec_splats(0.0f); + reg.val[3] = vec_splats(0.0f); + } + + explicit FP32Vec16(const float *ptr) { + reg.val[0] = vec_xl(0, ptr); + reg.val[1] = vec_xl(16, ptr); + reg.val[2] = vec_xl(32, ptr); + reg.val[3] = vec_xl(48, ptr); + } + + explicit FP32Vec16(f32x4x4_t data) : reg(data) {} + + explicit FP32Vec16(const FP32Vec16 &data) { + reg.val[0] = data.reg.val[0]; + reg.val[1] = data.reg.val[1]; + reg.val[2] = data.reg.val[2]; + reg.val[3] = data.reg.val[3]; + } + + explicit FP32Vec16(const FP32Vec4 &data) { + reg.val[0] = data.reg; + reg.val[1] = data.reg; + reg.val[2] = data.reg; + reg.val[3] = data.reg; + } + + explicit FP32Vec16(const FP32Vec8 &data) { + reg.val[0] = data.reg.val[0]; + reg.val[1] = data.reg.val[1]; + reg.val[2] = data.reg.val[0]; + reg.val[3] = data.reg.val[1]; + } + + explicit FP32Vec16(const BF16Vec16 &v) { + reg.val[0] = (__vector float)vec_mergeh(zero, v.reg.val[0]); + reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]); + reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]); + reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]); + } + + explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + + FP32Vec16 operator*(const FP32Vec16 &b) const { + return FP32Vec16(f32x4x4_t({ + vec_mul(reg.val[0], b.reg.val[0]), + vec_mul(reg.val[1], b.reg.val[1]), + vec_mul(reg.val[2], b.reg.val[2]), + vec_mul(reg.val[3], b.reg.val[3])})); + } + + FP32Vec16 operator+(const FP32Vec16 &b) const { + return FP32Vec16(f32x4x4_t({ + vec_add(reg.val[0], b.reg.val[0]), + vec_add(reg.val[1], b.reg.val[1]), + vec_add(reg.val[2], b.reg.val[2]), + vec_add(reg.val[3], b.reg.val[3])})); + } + + FP32Vec16 operator-(const FP32Vec16 &b) const { + return FP32Vec16(f32x4x4_t({ + vec_sub(reg.val[0], b.reg.val[0]), + vec_sub(reg.val[1], b.reg.val[1]), + vec_sub(reg.val[2], b.reg.val[2]), + vec_sub(reg.val[3], b.reg.val[3])})); + } + + FP32Vec16 operator/(const FP32Vec16 &b) const { + return FP32Vec16(f32x4x4_t({ + vec_div(reg.val[0], b.reg.val[0]), + vec_div(reg.val[1], b.reg.val[1]), + vec_div(reg.val[2], b.reg.val[2]), + vec_div(reg.val[3], b.reg.val[3])})); + } + + float reduce_sum() const { + AliasReg ar; + ar.reg = reg; + float result = 0; + unroll_loop([&result, &ar](int i) { result += ar.values[i]; }); + + return result; + } + + template float reduce_sub_sum(int idx) { + static_assert(VEC_ELEM_NUM % group_size == 0); + + AliasReg ar; + ar.reg = reg; + float result = 0; + const int start = idx * group_size; + unroll_loop( + [&result, &start, ar](int i) { result += ar.values[start + i]; }); + + return result; + } + + void save(float *ptr) const { + vec_xst(reg.val[0], 0, ptr); + vec_xst(reg.val[1], 16, ptr); + vec_xst(reg.val[2], 32, ptr); + vec_xst(reg.val[3], 48, ptr); + } +}; + +template struct VecType { using vec_type = void; }; + +template using vec_t = typename VecType::vec_type; + +template <> struct VecType { using vec_type = FP32Vec8; }; + +template <> struct VecType { using vec_type = BF16Vec8; }; + +template void storeFP32(float v, T *ptr) { *ptr = v; } + +inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) { + acc = acc + a * b; +} + +template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { + c10::BFloat16 __attribute__((__may_alias__)) *v_ptr = + reinterpret_cast(&v); + *ptr = *(v_ptr + 1); +} + +#ifndef __VEC_CLASS_FP_NAN +#define __VEC_CLASS_FP_NAN (1 << 6) +#endif + +const static __vector unsigned char omask = { 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29 }; +#ifndef _ARCH_PWR10 +const static __vector unsigned int bias = { 0x00007fff, 0x00007fff, 0x00007fff, 0x00007fff }; +const static __vector unsigned int nan = { 0x7fc00000, 0x7fc00000, 0x7fc00000, 0x7fc00000 }; +const static __vector unsigned int sh16 = { 16, 16, 16, 16 }; +const static __vector unsigned int one = { 1, 1, 1, 1 }; +#endif + +inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) { +#ifdef _ARCH_PWR10 + __vector signed short ret[2]; + ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[0]); + ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[1]); + reg = vec_perm(ret[0], ret[1], omask); +#elif defined(_ARCH_PWR9) + __vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]); + __vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]); + __vector unsigned int lsb0 = vec_sr(inp0, sh16); + __vector unsigned int lsb1 = vec_sr(inp1, sh16); + lsb0 = vec_and(lsb0, one); + lsb1 = vec_and(lsb1, one); + __vector unsigned int rnd0 = vec_add(lsb0, bias); + __vector unsigned int rnd1 = vec_add(lsb1, bias); + inp0 = vec_add(inp0, rnd0); + inp1 = vec_add(inp1, rnd1); + __vector __bool int sel0 = vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN); + __vector __bool int sel1 = vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN); + inp0 = vec_sel(inp0, nan, sel0); + inp1 = vec_sel(inp1, nan, sel1); + inp0 = vec_sr(inp0, sh16); + inp1 = vec_sr(inp1, sh16); + reg = (__vector signed short)vec_perm(inp0, inp1, omask); +#endif +} + +inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) { +#ifdef _ARCH_PWR10 + __vector signed short ret[4]; + ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[0]); + ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[1]); + ret[2] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[2]); + ret[3] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[3]); + reg.val[0] = vec_perm(ret[0], ret[1], omask); + reg.val[1] = vec_perm(ret[2], ret[3], omask); +#elif defined(_ARCH_PWR9) + __vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]); + __vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]); + __vector unsigned int inp2 = (__vector unsigned int)(v.reg.val[2]); + __vector unsigned int inp3 = (__vector unsigned int)(v.reg.val[3]); + __vector unsigned int lsb0 = vec_sr(inp0, sh16); + __vector unsigned int lsb1 = vec_sr(inp1, sh16); + __vector unsigned int lsb2 = vec_sr(inp2, sh16); + __vector unsigned int lsb3 = vec_sr(inp3, sh16); + lsb0 = vec_and(lsb0, one); + lsb1 = vec_and(lsb1, one); + lsb2 = vec_and(lsb2, one); + lsb3 = vec_and(lsb3, one); + __vector unsigned int rnd0 = vec_add(lsb0, bias); + __vector unsigned int rnd1 = vec_add(lsb1, bias); + __vector unsigned int rnd2 = vec_add(lsb2, bias); + __vector unsigned int rnd3 = vec_add(lsb3, bias); + inp0 = vec_add(inp0, rnd0); + inp1 = vec_add(inp1, rnd1); + inp2 = vec_add(inp2, rnd2); + inp3 = vec_add(inp3, rnd3); + __vector __bool int sel0 = vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN); + __vector __bool int sel1 = vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN); + __vector __bool int sel2 = vec_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN); + __vector __bool int sel3 = vec_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN); + inp0 = vec_sel(inp0, nan, sel0); + inp1 = vec_sel(inp1, nan, sel1); + inp2 = vec_sel(inp2, nan, sel2); + inp3 = vec_sel(inp3, nan, sel3); + inp0 = vec_sr(inp0, sh16); + inp1 = vec_sr(inp1, sh16); + inp2 = vec_sr(inp2, sh16); + inp3 = vec_sr(inp3, sh16); + reg.val[0] = (__vector signed short)vec_perm(inp0, inp1, omask); + reg.val[1] = (__vector signed short)vec_perm(inp2, inp3, omask); +#endif +} + +inline void prefetch(const void *addr) { + __asm__ __volatile__("dcbt 0, %0" : : "r"(addr) : "memory"); +} + +}; // namespace vec_op + +#endif diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp new file mode 100644 index 0000000000000..f50620a5287d4 --- /dev/null +++ b/csrc/cpu/cpu_types_x86.hpp @@ -0,0 +1,515 @@ + +#ifndef CPU_TYPES_X86_HPP +#define CPU_TYPES_X86_HPP + +#include +#include + +#ifndef __AVX2__ +static_assert(false, "AVX2 must be supported for the current implementation."); +#endif + +namespace vec_op { + +// FIXME: FP16 is not fully supported in Torch-CPU +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +#ifndef CPU_OP_GUARD +#define CPU_KERNEL_GUARD_IN(NAME) +#define CPU_KERNEL_GUARD_OUT(NAME) +#else +#define CPU_KERNEL_GUARD_IN(NAME) \ + std::cout << #NAME << " invoked." << std::endl; +#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl; +#endif + +#define FORCE_INLINE __attribute__((always_inline)) inline + +namespace { +template +constexpr void unroll_loop_item(std::integer_sequence, F &&f) { + (f(std::integral_constant{}), ...); +} +}; // namespace + +template >> +constexpr void unroll_loop(F &&f) { + unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); +} + +template struct Vec { + constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; } +}; + +struct FP32Vec8; +struct FP32Vec16; + +#ifdef __AVX512FP16__ +struct FP16Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + + __m128h reg; + + explicit FP16Vec8(_Float16 v) : reg(_mm_set1_ph(v)) {} + + explicit FP16Vec8(const void *ptr) : reg(_mm_loadu_ph(ptr)) {} + + explicit FP16Vec8(__m128h data) : reg(data) {} + + FP16Vec8 operator*(const FP16Vec8 &b) const { + return FP16Vec8(_mm_mul_ph(reg, b.reg)); + } + + FP16Vec8 operator+(const FP16Vec8 &b) const { + return FP16Vec8(_mm_add_ph(reg, b.reg)); + } + + FP16Vec8 operator-(const FP16Vec8 &b) const { + return FP16Vec8(_mm_sub_ph(reg, b.reg)); + } + + FP16Vec8 operator/(const FP16Vec8 &b) const { + return FP16Vec8(_mm_div_ph(reg, b.reg)); + } + + void save(void *ptr) const { _mm_storeu_ph(ptr, reg); } +}; +#endif + +struct BF16Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + + __m128i reg; + + explicit BF16Vec8(const void *ptr) + : reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {} + + explicit BF16Vec8(const FP32Vec8 &); + + void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; } +}; + +struct BF16Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + + __m256i reg; + + explicit BF16Vec16(const void *ptr) + : reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {} + + explicit BF16Vec16(const FP32Vec16 &); + + void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; } +}; + +#ifdef __AVX512F__ +struct BF16Vec32 : public Vec { + constexpr static int VEC_ELEM_NUM = 32; + + __m512i reg; + + explicit BF16Vec32(const void *ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {} + + explicit BF16Vec32(__m512i data) : reg(data) {} + + explicit BF16Vec32(BF16Vec8 &vec8_data) + : reg((__m512i)_mm512_inserti32x4( + _mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512( + (__m128i)vec8_data.reg), + (__m128i)vec8_data.reg, 1), + (__m128i)vec8_data.reg, 2), + (__m128i)vec8_data.reg, 3)) {} + + void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; } +}; +#else +struct BF16Vec32 : public Vec { + constexpr static int VEC_ELEM_NUM = 32; + + __m256i reg_low; + __m256i reg_high; + + explicit BF16Vec32(const void *ptr) + : reg_low(_mm256_loadu_si256((__m256i const *)ptr)), + reg_high(_mm256_loadu_si256((__m256i const *)ptr + 1)) {} + + explicit BF16Vec32(__m256i low, __m256i high) : reg_low(low), + reg_high(high) {} + + explicit BF16Vec32(BF16Vec8 &vec8_data) + : reg_low((__m256i)_mm256_inserti32x4( + _mm256_castsi128_si256((__m128i)vec8_data.reg), + (__m128i)vec8_data.reg, 1)), + reg_high((__m256i)_mm256_inserti32x4( + _mm256_castsi128_si256((__m128i)vec8_data.reg), + (__m128i)vec8_data.reg, 1)) {} + + void save(void *ptr) const { + *reinterpret_cast<__m256i *>(ptr) = reg_low; + *reinterpret_cast<__m256i *>((__m256i *)ptr + 1) = reg_high; + } +}; +#endif + +struct FP32Vec4 : public Vec { + constexpr static int VEC_ELEM_NUM = 4; + union AliasReg { + __m128 reg; + float values[VEC_ELEM_NUM]; + }; + + __m128 reg; + + explicit FP32Vec4(float v) : reg(_mm_set1_ps(v)) {} + + explicit FP32Vec4() : reg(_mm_set1_ps(0.0)) {} + + explicit FP32Vec4(const float *ptr) : reg(_mm_loadu_ps(ptr)) {} + + explicit FP32Vec4(__m128 data) : reg(data) {} + + explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {} +}; + +struct FP32Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + union AliasReg { + __m256 reg; + float values[VEC_ELEM_NUM]; + }; + + __m256 reg; + + explicit FP32Vec8(float v) : reg(_mm256_set1_ps(v)) {} + + explicit FP32Vec8() : reg(_mm256_set1_ps(0.0)) {} + + explicit FP32Vec8(const float *ptr) : reg(_mm256_loadu_ps(ptr)) {} + + explicit FP32Vec8(__m256 data) : reg(data) {} + + explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {} + +#ifdef __AVX512FP16__ + explicit FP32Vec8(__m128h v) : reg(_mm256_cvtph_ps(_mm_castph_si128(v))) {} +#endif + + explicit FP32Vec8(const BF16Vec8 &v) + : reg(_mm256_castsi256_ps( + _mm256_bslli_epi128(_mm256_cvtepu16_epi32(v.reg), 2))) {} + + float reduce_sum() const { + AliasReg ar; + ar.reg = reg; + float result = 0; + unroll_loop([&result, &ar](int i) { result += ar.values[i]; }); + + return result; + } + + FP32Vec8 exp() const { + AliasReg ar; + ar.reg = reg; + return FP32Vec8(_mm256_set_ps(expf(ar.values[7]), expf(ar.values[6]), + expf(ar.values[5]), expf(ar.values[4]), + expf(ar.values[3]), expf(ar.values[2]), + expf(ar.values[1]), expf(ar.values[0]))); + } + + FP32Vec8 tanh() const { + AliasReg ar; + ar.reg = reg; + return FP32Vec8(_mm256_set_ps(tanhf(ar.values[7]), tanhf(ar.values[6]), + tanhf(ar.values[5]), tanhf(ar.values[4]), + tanhf(ar.values[3]), tanhf(ar.values[2]), + tanhf(ar.values[1]), tanhf(ar.values[0]))); + } + + FP32Vec8 er() const { + AliasReg ar; + ar.reg = reg; + return FP32Vec8(_mm256_set_ps(erf(ar.values[7]), erf(ar.values[6]), + erf(ar.values[5]), erf(ar.values[4]), + erf(ar.values[3]), erf(ar.values[2]), + erf(ar.values[1]), erf(ar.values[0]))); + } + + FP32Vec8 operator*(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_mul_ps(reg, b.reg)); + } + + FP32Vec8 operator+(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_add_ps(reg, b.reg)); + } + + FP32Vec8 operator-(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_sub_ps(reg, b.reg)); + } + + FP32Vec8 operator/(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_div_ps(reg, b.reg)); + } + + void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); } +}; + +#ifdef __AVX512F__ +struct FP32Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + union AliasReg { + __m512 reg; + float values[VEC_ELEM_NUM]; + }; + + __m512 reg; + + explicit FP32Vec16(float v) : reg(_mm512_set1_ps(v)) {} + + explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {} + + explicit FP32Vec16(const float *ptr) : reg(_mm512_loadu_ps(ptr)) {} + + explicit FP32Vec16(__m512 data) : reg(data) {} + + explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {} + + explicit FP32Vec16(const FP32Vec4 &data) + : reg((__m512)_mm512_inserti32x4( + _mm512_inserti32x4( + _mm512_inserti32x4(_mm512_castsi128_si512((__m128i)data.reg), + (__m128i)data.reg, 1), + (__m128i)data.reg, 2), + (__m128i)data.reg, 3)) {} + + explicit FP32Vec16(const FP32Vec8 &data) + : reg((__m512)_mm512_inserti32x8( + _mm512_castsi256_si512((__m256i)data.reg), (__m256i)data.reg, 1)) {} + + explicit FP32Vec16(const BF16Vec16 &v) + : reg(_mm512_castsi512_ps( + _mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {} + + explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + + FP32Vec16 operator*(const FP32Vec16 &b) const { + return FP32Vec16(_mm512_mul_ps(reg, b.reg)); + } + + FP32Vec16 operator+(const FP32Vec16 &b) const { + return FP32Vec16(_mm512_add_ps(reg, b.reg)); + } + + FP32Vec16 operator-(const FP32Vec16 &b) const { + return FP32Vec16(_mm512_sub_ps(reg, b.reg)); + } + + FP32Vec16 operator/(const FP32Vec16 &b) const { + return FP32Vec16(_mm512_div_ps(reg, b.reg)); + } + + float reduce_sum() const { return _mm512_reduce_add_ps(reg); } + + template float reduce_sub_sum(int idx) { + static_assert(VEC_ELEM_NUM % group_size == 0); + constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); + __mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size)); + return _mm512_mask_reduce_add_ps(mask, reg); + } + + void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); } +}; +#else +struct FP32Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + + union AliasReg { + __m256 reg; + float values[8]; + }; + + __m256 reg_low; + __m256 reg_high; + + explicit FP32Vec16(float v) : reg_low(_mm256_set1_ps(v)), + reg_high(_mm256_set1_ps(v)) {} + + explicit FP32Vec16() : reg_low(_mm256_set1_ps(0.0)), + reg_high(_mm256_set1_ps(0.0)) {} + + explicit FP32Vec16(const float *ptr) : reg_low(_mm256_loadu_ps(ptr)), + reg_high(_mm256_loadu_ps(ptr + 8)) {} + + explicit FP32Vec16(__m256 low, __m256 high) : reg_low(low), reg_high(high) {} + + explicit FP32Vec16(const FP32Vec16 &data) : reg_low(data.reg_low), + reg_high(data.reg_high) {} + + explicit FP32Vec16(const FP32Vec4 &data) + : reg_low((__m256)_mm256_inserti128_si256( + _mm256_castsi128_si256((__m128i)data.reg), + (__m128i)data.reg, 1)), + reg_high((__m256)_mm256_inserti128_si256( + _mm256_castsi128_si256((__m128i)data.reg), + (__m128i)data.reg, 1)) {} + + explicit FP32Vec16(const FP32Vec8 &data) + : reg_low(data.reg), reg_high(data.reg) {} + + explicit FP32Vec16(const BF16Vec16 &v) { + __m128i low = _mm256_extractf128_si256(v.reg, 0); + __m128i high = _mm256_extractf128_si256(v.reg, 1); + + __m256i v_low_epi32 = _mm256_cvtepu16_epi32(low); + __m256i v_high_epi32 = _mm256_cvtepu16_epi32(high); + + __m256i v_low_shifted = _mm256_bslli_epi128(v_low_epi32, 2); + __m256i v_high_shifted = _mm256_bslli_epi128(v_high_epi32, 2); + + reg_low = _mm256_castsi256_ps(v_low_shifted); + reg_high = _mm256_castsi256_ps(v_high_shifted); + } + + explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + + FP32Vec16 operator*(const FP32Vec16 &b) const { + return FP32Vec16(_mm256_mul_ps(reg_low, b.reg_low), + _mm256_mul_ps(reg_high, b.reg_high)); + } + + FP32Vec16 operator+(const FP32Vec16 &b) const { + return FP32Vec16(_mm256_add_ps(reg_low, b.reg_low), + _mm256_add_ps(reg_high, b.reg_high)); + } + + FP32Vec16 operator-(const FP32Vec16 &b) const { + return FP32Vec16(_mm256_sub_ps(reg_low, b.reg_low), + _mm256_sub_ps(reg_high, b.reg_high)); + } + + FP32Vec16 operator/(const FP32Vec16 &b) const { + return FP32Vec16(_mm256_div_ps(reg_low, b.reg_low), + _mm256_div_ps(reg_high, b.reg_high)); + } + + float reduce_sum() const { + FP32Vec8 low = FP32Vec8(reg_low); + FP32Vec8 high = FP32Vec8(reg_high); + return low.reduce_sum() + high.reduce_sum(); + } + + template float reduce_sub_sum(int idx) { + float sum = 0.0; + static_assert(VEC_ELEM_NUM % group_size == 0); + constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); + uint32_t mask = base_mask << (idx * group_size); + + AliasReg ar; + + auto func = [&sum, &mask, &ar](int i) { + int flag = mask & 0x1; + mask = mask >> 1; + if (flag != 0) sum += ar.values[i]; + }; + + ar.reg = reg_low; + unroll_loop(func); + + ar.reg = reg_high; + unroll_loop(func); + + return sum; + } + + void save(float *ptr) const { + _mm256_storeu_ps(ptr, reg_low); + _mm256_storeu_ps(ptr + 8, reg_high); + } +}; +#endif + +template struct VecType { using vec_type = void; }; + +template using vec_t = typename VecType::vec_type; + +template <> struct VecType { using vec_type = FP32Vec8; }; + +#ifdef __AVX512FP16__ +template <> struct VecType { using vec_type = FP16Vec16; }; +#endif + +template <> struct VecType { using vec_type = BF16Vec8; }; + +template void storeFP32(float v, T *ptr) { *ptr = v; } + +#ifdef __AVX512FP16__ +template <> inline void storeFP32(float v, c10::Half *ptr) { + *reinterpret_cast<_Float16 *>(ptr) = v; +} +#endif + +inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) { + acc = acc + a * b; +} + +#ifdef __AVX512BF16__ +template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { + *reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v); +} + +inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) + : reg((__m128i)_mm256_cvtneps_pbh(v.reg)) {} + +inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) + : reg((__m256i)_mm512_cvtneps_pbh(v.reg)) {} + +inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) { + acc.reg = _mm512_dpbf16_ps(acc.reg, (__m512bh)a.reg, (__m512bh)b.reg); +} +#else +template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { + c10::BFloat16 __attribute__((__may_alias__)) *v_ptr = + reinterpret_cast(&v); + *ptr = *(v_ptr + 1); +} + +#ifdef __AVX512F__ +inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) + : reg(_mm256_cvtepi32_epi16( + _mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {} + +inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) + : reg(_mm512_cvtepi32_epi16( + _mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {} +#else +namespace{ +__m128i FP32Vec8_to_BF16Vec8_avx2(__m256 a) { + __m256i ai = _mm256_castps_si256(a); + ai = _mm256_srli_epi32(ai, 16); + ai = _mm256_packus_epi32(ai, ai); + ai = _mm256_permute4x64_epi64(ai, 0b00111001); + return _mm256_extracti128_si256(ai, 0); +} +} + +inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) + : reg(FP32Vec8_to_BF16Vec8_avx2(v.reg)) {} + +inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) { + BF16Vec8 low = BF16Vec8(FP32Vec8(v.reg_low)); + BF16Vec8 high = BF16Vec8(FP32Vec8(v.reg_high)); + reg = _mm256_insertf128_si256(_mm256_castsi128_si256(low.reg), high.reg, 1); +} +#endif // __AVX512F__ +#endif // __AVX512BF16__ + +inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); } + +}; // namespace vec_op + +#endif diff --git a/csrc/ops.h b/csrc/ops.h index ae04150eaf756..8a92afdc81a9b 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -1,5 +1,6 @@ #pragma once +#include #include void paged_attention_v1( diff --git a/requirements-cpu.txt b/requirements-cpu.txt index 21acee91d7b57..754070df21c0a 100644 --- a/requirements-cpu.txt +++ b/requirements-cpu.txt @@ -2,6 +2,6 @@ -r requirements-common.txt # Dependencies for x86_64 CPUs -torch == 2.3.1+cpu -torchvision == 0.18.1+cpu # required for the image processor of phi3v, this must be updated alongside torch -triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error. \ No newline at end of file +torch == 2.3.1+cpu; platform_machine != "ppc64le" +torchvision == 0.18.1+cpu; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch +triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.