From a4b3e0c1e999d214c6355b16a1c68250e6c030e2 Mon Sep 17 00:00:00 2001 From: "Li, Jiang" Date: Thu, 7 Nov 2024 12:43:08 +0800 Subject: [PATCH] [Hardware][CPU] Update torch 2.5 (#9911) Signed-off-by: jiang1.li --- .buildkite/run-cpu-test.sh | 2 +- Dockerfile.cpu | 2 +- cmake/cpu_extension.cmake | 1 + csrc/cpu/attention.cpp | 10 +++ csrc/cpu/cpu_types_x86.hpp | 78 +++++++++++-------- csrc/cpu/dnnl_helper.hpp | 6 ++ csrc/cpu/quant.cpp | 7 ++ .../getting_started/cpu-installation.rst | 6 +- requirements-cpu.txt | 2 +- .../decoder_only/language/test_models.py | 3 +- vllm/executor/cpu_executor.py | 5 -- .../layers/quantization/ipex_quant.py | 2 +- 12 files changed, 76 insertions(+), 48 deletions(-) diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index c331a9c49c0d0..2dbeee8562971 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -46,7 +46,7 @@ docker exec cpu-test bash -c " docker exec cpu-test bash -c " export VLLM_CPU_KVCACHE_SPACE=10 export VLLM_CPU_OMP_THREADS_BIND=48-92 - python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m & + python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half & timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1 python3 benchmarks/benchmark_serving.py \ --backend vllm \ diff --git a/Dockerfile.cpu b/Dockerfile.cpu index f1a21d6bd13fc..287b4958da4e5 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -22,7 +22,7 @@ ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/li RUN echo 'ulimit -c 0' >> ~/.bashrc -RUN pip install intel_extension_for_pytorch==2.4.0 +RUN pip install intel_extension_for_pytorch==2.5.0 WORKDIR /workspace diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 7237d246ddf55..776a0bb11ae64 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -18,6 +18,7 @@ include_directories("${CMAKE_SOURCE_DIR}/csrc") # list(APPEND CXX_COMPILE_FLAGS "-fopenmp" + "-mf16c" "-DVLLM_CPU_EXTENSION") execute_process(COMMAND cat /proc/cpuinfo diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index abb4e3bea14bb..e3953c7c45719 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -22,6 +22,16 @@ struct KernelVecType { using v_load_vec_type = vec_op::FP32Vec16; }; +template <> +struct KernelVecType { + using q_load_vec_type = vec_op::FP16Vec8; + using q_vec_type = vec_op::FP32Vec16; + using k_load_vec_type = vec_op::FP16Vec16; + using k_vec_type = vec_op::FP32Vec16; + using qk_acc_vec_type = vec_op::FP32Vec16; + using v_load_vec_type = vec_op::FP16Vec16; +}; + #ifdef __AVX512BF16__ template <> struct KernelVecType { diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp index a325153b470cc..12d5757b495be 100644 --- a/csrc/cpu/cpu_types_x86.hpp +++ b/csrc/cpu/cpu_types_x86.hpp @@ -11,10 +11,10 @@ static_assert(false, "AVX2 must be supported for the current implementation."); namespace vec_op { -// FIXME: FP16 is not fully supported in Torch-CPU #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) @@ -50,37 +50,37 @@ template struct Vec { struct FP32Vec8; struct FP32Vec16; -#ifdef __AVX512FP16__ struct FP16Vec8 : public Vec { constexpr static int VEC_ELEM_NUM = 8; - __m128h reg; + __m128i reg; - explicit FP16Vec8(_Float16 v) : reg(_mm_set1_ph(v)) {} + explicit FP16Vec8(const void *ptr) + : reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {} - explicit FP16Vec8(const void *ptr) : reg(_mm_loadu_ph(ptr)) {} + explicit FP16Vec8(const FP32Vec8 &); - explicit FP16Vec8(__m128h data) : reg(data) {} + void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; } +}; - FP16Vec8 operator*(const FP16Vec8 &b) const { - return FP16Vec8(_mm_mul_ph(reg, b.reg)); - } +struct FP16Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; - FP16Vec8 operator+(const FP16Vec8 &b) const { - return FP16Vec8(_mm_add_ph(reg, b.reg)); - } + __m256i reg; - FP16Vec8 operator-(const FP16Vec8 &b) const { - return FP16Vec8(_mm_sub_ph(reg, b.reg)); - } + explicit FP16Vec16(const void *ptr) + : reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {} - FP16Vec8 operator/(const FP16Vec8 &b) const { - return FP16Vec8(_mm_div_ph(reg, b.reg)); - } + explicit FP16Vec16(const FP32Vec16 &); - void save(void *ptr) const { _mm_storeu_ph(ptr, reg); } + void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; } + + void save(void* ptr, const int elem_num) const { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + _mm256_mask_storeu_epi16(ptr, mask, reg); + } }; -#endif struct BF16Vec8 : public Vec { constexpr static int VEC_ELEM_NUM = 8; @@ -202,9 +202,7 @@ struct FP32Vec8 : public Vec { explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {} -#ifdef __AVX512FP16__ - explicit FP32Vec8(__m128h v) : reg(_mm256_cvtph_ps(_mm_castph_si128(v))) {} -#endif + explicit FP32Vec8(const FP16Vec8 &v) : reg(_mm256_cvtph_ps(v.reg)) {} explicit FP32Vec8(const BF16Vec8 &v) : reg(_mm256_castsi256_ps( @@ -323,6 +321,10 @@ struct FP32Vec16 : public Vec { : reg(_mm512_castsi512_ps( _mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {} + explicit FP32Vec16(const FP16Vec16 &v) : reg(_mm512_cvtph_ps(v.reg)) {} + + explicit FP32Vec16(const FP16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} explicit FP32Vec16(const INT32Vec16 &v) @@ -534,24 +536,34 @@ template using vec_t = typename VecType::vec_type; template <> struct VecType { using vec_type = FP32Vec8; }; -#ifdef __AVX512FP16__ -template <> struct VecType { using vec_type = FP16Vec16; }; -#endif +template <> struct VecType { using vec_type = FP16Vec8; }; template <> struct VecType { using vec_type = BF16Vec8; }; template void storeFP32(float v, T *ptr) { *ptr = v; } -#ifdef __AVX512FP16__ -template <> inline void storeFP32(float v, c10::Half *ptr) { - *reinterpret_cast<_Float16 *>(ptr) = v; -} -#endif - inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) { acc = acc + a * b; } +template <> inline void storeFP32(float v, c10::Half *ptr) { + *reinterpret_cast(ptr) = + _cvtss_sh(v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); +} + +inline FP16Vec8::FP16Vec8(const FP32Vec8 &v) + : reg(_mm256_cvtps_ph(v.reg, + _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {} + +#ifdef __AVX512F__ +inline FP16Vec16::FP16Vec16(const FP32Vec16 &v) + : reg(_mm512_cvtps_ph(v.reg, + _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {} +#else +inline FP16Vec16::FP16Vec16(const FP32Vec16 &v) + : reg(_mm256_insertf128_si256(_mm256_castsi128_si256(FP16Vec8(FP32Vec8(v.reg_low)).reg), FP16Vec8(FP32Vec8(v.reg_low)).reg, 1)) {} +#endif + #ifdef __AVX512BF16__ template <> inline void storeFP32(float v, c10::BFloat16 *ptr) { *reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v); diff --git a/csrc/cpu/dnnl_helper.hpp b/csrc/cpu/dnnl_helper.hpp index 024ad4ae43da8..8b5011dc065f0 100644 --- a/csrc/cpu/dnnl_helper.hpp +++ b/csrc/cpu/dnnl_helper.hpp @@ -2,6 +2,7 @@ #define DNNL_HELPER_HPP #include +#include #include "oneapi/dnnl/dnnl.hpp" @@ -32,6 +33,11 @@ struct DNNLType { static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16; }; +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16; +}; + template constexpr inline dnnl::memory::data_type get_dnnl_type() { return DNNLType>::type; diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp index b493fd793818a..f42fa2361a2db 100644 --- a/csrc/cpu/quant.cpp +++ b/csrc/cpu/quant.cpp @@ -23,6 +23,13 @@ struct KernelVecType { using cvt_vec_type = vec_op::FP32Vec16; }; +template <> +struct KernelVecType { + using load_vec_type = vec_op::FP16Vec16; + using azp_adj_load_vec_type = vec_op::INT32Vec16; + using cvt_vec_type = vec_op::FP32Vec16; +}; + #ifdef __AVX512F__ template void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, diff --git a/docs/source/getting_started/cpu-installation.rst b/docs/source/getting_started/cpu-installation.rst index d12aeebbbc184..69530fd778c55 100644 --- a/docs/source/getting_started/cpu-installation.rst +++ b/docs/source/getting_started/cpu-installation.rst @@ -3,13 +3,13 @@ Installation with CPU ======================== -vLLM initially supports basic model inferencing and serving on x86 CPU platform, with data types FP32 and BF16. vLLM CPU backend supports the following vLLM features: +vLLM initially supports basic model inferencing and serving on x86 CPU platform, with data types FP32, FP16 and BF16. vLLM CPU backend supports the following vLLM features: - Tensor Parallel (``-tp = N``) - Quantization (``INT8 W8A8, AWQ``) .. note:: - FP16 data type and more advanced features on `chunked-prefill`, `prefix-caching` and `FP8 KV cache` are under development and will be available soon. + More advanced features on `chunked-prefill`, `prefix-caching` and `FP8 KV cache` are under development and will be available soon. Table of contents: @@ -72,8 +72,6 @@ Build from source $ VLLM_TARGET_DEVICE=cpu python setup.py install .. note:: - - BF16 is the default data type in the current CPU backend (that means the backend will cast FP16 to BF16), and is compatible will all CPUs with AVX512 ISA support. - - AVX512_BF16 is an extension ISA provides native BF16 data type conversion and vector product instructions, will brings some performance improvement compared with pure AVX512. The CPU backend build script will check the host CPU flags to determine whether to enable AVX512_BF16. - If you want to force enable AVX512_BF16 for the cross-compilation, please set environment variable VLLM_CPU_AVX512BF16=1 before the building. diff --git a/requirements-cpu.txt b/requirements-cpu.txt index 27ca8ca5dbc58..749b03a0603d8 100644 --- a/requirements-cpu.txt +++ b/requirements-cpu.txt @@ -2,5 +2,5 @@ -r requirements-common.txt # Dependencies for x86_64 CPUs -torch == 2.4.0+cpu; platform_machine != "ppc64le" +torch == 2.5.1+cpu; platform_machine != "ppc64le" torchvision; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch diff --git a/tests/models/decoder_only/language/test_models.py b/tests/models/decoder_only/language/test_models.py index 05117666f8c3f..d705909c24bf8 100644 --- a/tests/models/decoder_only/language/test_models.py +++ b/tests/models/decoder_only/language/test_models.py @@ -32,8 +32,7 @@ "openbmb/MiniCPM3-4B", ] -# TODO: remove this after CPU float16 support ready -target_dtype = "float" if current_platform.is_cpu() else "half" +target_dtype = "half" @pytest.mark.parametrize("model", MODELS) diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index ab3ebb4e43d18..4ceb5a837dd7f 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -2,8 +2,6 @@ from functools import partial from typing import Any, Awaitable, List, Optional, Set, Tuple, Union -import torch - import vllm.envs as envs from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig) @@ -316,9 +314,6 @@ async def check_health_async(self) -> None: def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig: - if config.dtype == torch.float16: - logger.warning("float16 is not supported on CPU, casting to bfloat16.") - config.dtype = torch.bfloat16 # Reminder: Please update docs/source/serving/compatibility_matrix.rst # If the feature combo become valid if not config.enforce_eager: diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py index 43f4502f7455c..330c2ad195d78 100644 --- a/vllm/model_executor/layers/quantization/ipex_quant.py +++ b/vllm/model_executor/layers/quantization/ipex_quant.py @@ -54,7 +54,7 @@ def get_name(cls) -> str: @classmethod def get_supported_act_dtypes(cls) -> List[torch.dtype]: - return [torch.bfloat16] + return [torch.bfloat16, torch.float16] @classmethod def get_min_capability(cls) -> int: