Skip to content

Commit

Permalink
[Hardware][CPU] Update torch 2.5 (#9911)
Browse files Browse the repository at this point in the history
Signed-off-by: jiang1.li <[email protected]>
  • Loading branch information
bigPYJ1151 authored Nov 7, 2024
1 parent 29862b8 commit a4b3e0c
Show file tree
Hide file tree
Showing 12 changed files with 76 additions and 48 deletions.
2 changes: 1 addition & 1 deletion .buildkite/run-cpu-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ docker exec cpu-test bash -c "
docker exec cpu-test bash -c "
export VLLM_CPU_KVCACHE_SPACE=10
export VLLM_CPU_OMP_THREADS_BIND=48-92
python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m &
python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half &
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
python3 benchmarks/benchmark_serving.py \
--backend vllm \
Expand Down
2 changes: 1 addition & 1 deletion Dockerfile.cpu
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/li

RUN echo 'ulimit -c 0' >> ~/.bashrc

RUN pip install intel_extension_for_pytorch==2.4.0
RUN pip install intel_extension_for_pytorch==2.5.0

WORKDIR /workspace

Expand Down
1 change: 1 addition & 0 deletions cmake/cpu_extension.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ include_directories("${CMAKE_SOURCE_DIR}/csrc")
#
list(APPEND CXX_COMPILE_FLAGS
"-fopenmp"
"-mf16c"
"-DVLLM_CPU_EXTENSION")

execute_process(COMMAND cat /proc/cpuinfo
Expand Down
10 changes: 10 additions & 0 deletions csrc/cpu/attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ struct KernelVecType<float> {
using v_load_vec_type = vec_op::FP32Vec16;
};

template <>
struct KernelVecType<c10::Half> {
using q_load_vec_type = vec_op::FP16Vec8;
using q_vec_type = vec_op::FP32Vec16;
using k_load_vec_type = vec_op::FP16Vec16;
using k_vec_type = vec_op::FP32Vec16;
using qk_acc_vec_type = vec_op::FP32Vec16;
using v_load_vec_type = vec_op::FP16Vec16;
};

#ifdef __AVX512BF16__
template <>
struct KernelVecType<c10::BFloat16> {
Expand Down
78 changes: 45 additions & 33 deletions csrc/cpu/cpu_types_x86.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ static_assert(false, "AVX2 must be supported for the current implementation.");

namespace vec_op {

// FIXME: FP16 is not fully supported in Torch-CPU
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
Expand Down Expand Up @@ -50,37 +50,37 @@ template <typename T> struct Vec {
struct FP32Vec8;
struct FP32Vec16;

#ifdef __AVX512FP16__
struct FP16Vec8 : public Vec<FP16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;

__m128h reg;
__m128i reg;

explicit FP16Vec8(_Float16 v) : reg(_mm_set1_ph(v)) {}
explicit FP16Vec8(const void *ptr)
: reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {}

explicit FP16Vec8(const void *ptr) : reg(_mm_loadu_ph(ptr)) {}
explicit FP16Vec8(const FP32Vec8 &);

explicit FP16Vec8(__m128h data) : reg(data) {}
void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; }
};

FP16Vec8 operator*(const FP16Vec8 &b) const {
return FP16Vec8(_mm_mul_ph(reg, b.reg));
}
struct FP16Vec16 : public Vec<FP16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;

FP16Vec8 operator+(const FP16Vec8 &b) const {
return FP16Vec8(_mm_add_ph(reg, b.reg));
}
__m256i reg;

FP16Vec8 operator-(const FP16Vec8 &b) const {
return FP16Vec8(_mm_sub_ph(reg, b.reg));
}
explicit FP16Vec16(const void *ptr)
: reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {}

FP16Vec8 operator/(const FP16Vec8 &b) const {
return FP16Vec8(_mm_div_ph(reg, b.reg));
}
explicit FP16Vec16(const FP32Vec16 &);

void save(void *ptr) const { _mm_storeu_ph(ptr, reg); }
void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; }

void save(void* ptr, const int elem_num) const {
constexpr uint32_t M = 0xFFFFFFFF;
__mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num));
_mm256_mask_storeu_epi16(ptr, mask, reg);
}
};
#endif

struct BF16Vec8 : public Vec<BF16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
Expand Down Expand Up @@ -202,9 +202,7 @@ struct FP32Vec8 : public Vec<FP32Vec8> {

explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {}

#ifdef __AVX512FP16__
explicit FP32Vec8(__m128h v) : reg(_mm256_cvtph_ps(_mm_castph_si128(v))) {}
#endif
explicit FP32Vec8(const FP16Vec8 &v) : reg(_mm256_cvtph_ps(v.reg)) {}

explicit FP32Vec8(const BF16Vec8 &v)
: reg(_mm256_castsi256_ps(
Expand Down Expand Up @@ -323,6 +321,10 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
: reg(_mm512_castsi512_ps(
_mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {}

explicit FP32Vec16(const FP16Vec16 &v) : reg(_mm512_cvtph_ps(v.reg)) {}

explicit FP32Vec16(const FP16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}

explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}

explicit FP32Vec16(const INT32Vec16 &v)
Expand Down Expand Up @@ -534,24 +536,34 @@ template <typename T> using vec_t = typename VecType<T>::vec_type;

template <> struct VecType<float> { using vec_type = FP32Vec8; };

#ifdef __AVX512FP16__
template <> struct VecType<c10::Half> { using vec_type = FP16Vec16; };
#endif
template <> struct VecType<c10::Half> { using vec_type = FP16Vec8; };

template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };

template <typename T> void storeFP32(float v, T *ptr) { *ptr = v; }

#ifdef __AVX512FP16__
template <> inline void storeFP32<c10::Half>(float v, c10::Half *ptr) {
*reinterpret_cast<_Float16 *>(ptr) = v;
}
#endif

inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
acc = acc + a * b;
}

template <> inline void storeFP32<c10::Half>(float v, c10::Half *ptr) {
*reinterpret_cast<unsigned short *>(ptr) =
_cvtss_sh(v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
}

inline FP16Vec8::FP16Vec8(const FP32Vec8 &v)
: reg(_mm256_cvtps_ph(v.reg,
_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {}

#ifdef __AVX512F__
inline FP16Vec16::FP16Vec16(const FP32Vec16 &v)
: reg(_mm512_cvtps_ph(v.reg,
_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {}
#else
inline FP16Vec16::FP16Vec16(const FP32Vec16 &v)
: reg(_mm256_insertf128_si256(_mm256_castsi128_si256(FP16Vec8(FP32Vec8(v.reg_low)).reg), FP16Vec8(FP32Vec8(v.reg_low)).reg, 1)) {}
#endif

#ifdef __AVX512BF16__
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
*reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v);
Expand Down
6 changes: 6 additions & 0 deletions csrc/cpu/dnnl_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define DNNL_HELPER_HPP

#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>

#include "oneapi/dnnl/dnnl.hpp"

Expand Down Expand Up @@ -32,6 +33,11 @@ struct DNNLType<c10::BFloat16> {
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16;
};

template <>
struct DNNLType<c10::Half> {
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16;
};

template <typename T>
constexpr inline dnnl::memory::data_type get_dnnl_type() {
return DNNLType<std::decay_t<T>>::type;
Expand Down
7 changes: 7 additions & 0 deletions csrc/cpu/quant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ struct KernelVecType<c10::BFloat16> {
using cvt_vec_type = vec_op::FP32Vec16;
};

template <>
struct KernelVecType<c10::Half> {
using load_vec_type = vec_op::FP16Vec16;
using azp_adj_load_vec_type = vec_op::INT32Vec16;
using cvt_vec_type = vec_op::FP32Vec16;
};

#ifdef __AVX512F__
template <bool AZP, typename scalar_t>
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
Expand Down
6 changes: 2 additions & 4 deletions docs/source/getting_started/cpu-installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
Installation with CPU
========================

vLLM initially supports basic model inferencing and serving on x86 CPU platform, with data types FP32 and BF16. vLLM CPU backend supports the following vLLM features:
vLLM initially supports basic model inferencing and serving on x86 CPU platform, with data types FP32, FP16 and BF16. vLLM CPU backend supports the following vLLM features:

- Tensor Parallel (``-tp = N``)
- Quantization (``INT8 W8A8, AWQ``)

.. note::
FP16 data type and more advanced features on `chunked-prefill`, `prefix-caching` and `FP8 KV cache` are under development and will be available soon.
More advanced features on `chunked-prefill`, `prefix-caching` and `FP8 KV cache` are under development and will be available soon.

Table of contents:

Expand Down Expand Up @@ -72,8 +72,6 @@ Build from source
$ VLLM_TARGET_DEVICE=cpu python setup.py install
.. note::
- BF16 is the default data type in the current CPU backend (that means the backend will cast FP16 to BF16), and is compatible will all CPUs with AVX512 ISA support.

- AVX512_BF16 is an extension ISA provides native BF16 data type conversion and vector product instructions, will brings some performance improvement compared with pure AVX512. The CPU backend build script will check the host CPU flags to determine whether to enable AVX512_BF16.

- If you want to force enable AVX512_BF16 for the cross-compilation, please set environment variable VLLM_CPU_AVX512BF16=1 before the building.
Expand Down
2 changes: 1 addition & 1 deletion requirements-cpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
-r requirements-common.txt

# Dependencies for x86_64 CPUs
torch == 2.4.0+cpu; platform_machine != "ppc64le"
torch == 2.5.1+cpu; platform_machine != "ppc64le"
torchvision; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch
3 changes: 1 addition & 2 deletions tests/models/decoder_only/language/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@
"openbmb/MiniCPM3-4B",
]

# TODO: remove this after CPU float16 support ready
target_dtype = "float" if current_platform.is_cpu() else "half"
target_dtype = "half"


@pytest.mark.parametrize("model", MODELS)
Expand Down
5 changes: 0 additions & 5 deletions vllm/executor/cpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
from functools import partial
from typing import Any, Awaitable, List, Optional, Set, Tuple, Union

import torch

import vllm.envs as envs
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
Expand Down Expand Up @@ -316,9 +314,6 @@ async def check_health_async(self) -> None:


def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
if config.dtype == torch.float16:
logger.warning("float16 is not supported on CPU, casting to bfloat16.")
config.dtype = torch.bfloat16
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# If the feature combo become valid
if not config.enforce_eager:
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/quantization/ipex_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def get_name(cls) -> str:

@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.bfloat16]
return [torch.bfloat16, torch.float16]

@classmethod
def get_min_capability(cls) -> int:
Expand Down

0 comments on commit a4b3e0c

Please sign in to comment.