diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7717365ba..463bb2117 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -59,11 +59,7 @@ option(NE_AVX512_VBMI "neural_engine: enable AVX512-VBMI"
 option(NE_AVX512_VNNI "neural_engine: enable AVX512-VNNI" OFF)
 option(NE_FMA "neural_engine: enable FMA" ON)
 option(NE_AMX "neural_engine: enable AMX" OFF)
-
-# in MSVC F16C is implied with AVX2/AVX512
-if (NOT MSVC)
-    option(NE_F16C "neural_engine: enable F16C" ON)
-endif()
+option(NE_F16C "neural_engine: enable F16C" ON)
 
 # 3rd party libs
 option(NE_ONEDNN "neural_engine: use oneDNN" ON)
@@ -93,6 +89,8 @@ if (NE_GELU_VEC)
 endif()
 option(NE_PYTHON_API "neural_engine: use python api" OFF)
 option(NE_SIMD_VEC_DOT_F16 "neural_engine: enable vec_dot_fp16 SIMD optimization" ON)
+option(BUILD_SHARED_LIBS "If build as shared libs" ON)
+
 if (NE_SIMD_VEC_DOT_F16)
     add_compile_definitions(NE_SIMD_VEC_DOT_F16)
 endif()
@@ -103,7 +101,6 @@ endif()
 
 if (MSVC)
     add_compile_definitions(_CRT_SECURE_NO_WARNINGS NOMINMAX)
-
     if (BUILD_SHARED_LIBS)
         set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
     endif()
diff --git a/bestla/jblas/jit_blas_parallel.h b/bestla/jblas/jit_blas_parallel.h
index e9a63e142..5e6b3b650 100644
--- a/bestla/jblas/jit_blas_parallel.h
+++ b/bestla/jblas/jit_blas_parallel.h
@@ -204,7 +204,7 @@ class SchedulerBase : public Scheduler2D {
     mL2Use += static_cast<size_t>(mBlock[1]) * mBlock[2] * mEleSize[1];
     mL2Use += static_cast<size_t>(mStep[0]) * mBlock[2] * mEleSize[0];
   }
-  const float DensityThres = 32;
+  const float DensityThres = 16;
   static size_t constexpr ReservedSize = 32ULL * 1024ULL;
 
   virtual float calculate_score() {
@@ -364,7 +364,7 @@ class SchedulerKBlock : public Scheduler2D {
     mL2Use += static_cast<size_t>(mBlock[1]) * mBlock[2] * mEleSize[1];
     mL2Use += static_cast<size_t>(mStep[0]) * mBlock[2] * mEleSize[0];
   }
-  const float DensityThres = 32;
+  const float DensityThres = 16;
 
   float calculate_score() {
     int tmpnstep = mThdSize[1] < _GemmCore_T::PREFERRED_N ? mThdSize[1] : _GemmCore_T::PREFERRED_N;
@@ -489,13 +489,14 @@ class SchedulerKBlockS : public SchedulerBase<_GemmCore_T> {
     this->mL2Use += static_cast<size_t>(blks) * (this->mBlock[1] + this->mStep[0]) *
                     (sizeof(float) + sizeof(int8_t) + sizeof(float));  // scale+zp+reduce
     assert(this->mL2Use <= this->mL2Size - ReservedSize);
-    assert(this->mBlock[0]>0);
-    assert(this->mBlock[1]>0);
-    assert(this->mBlock[2]>0);
+    assert(this->mBlock[0] > 0);
+    assert(this->mBlock[1] > 0);
+    assert(this->mBlock[2] > 0);
+    assert(this->mBlock[2] % _GemmCore_T::KTILE == 0);
   }
 
 protected:
-  const float DensityThres = 32;
+  const float DensityThres = 16;
   static size_t constexpr ReservedSize = 32ULL * 1024ULL;
 
   void cache_blocking_compute() override {
@@ -529,6 +530,11 @@ class SchedulerKBlockS : public SchedulerBase<_GemmCore_T> {
                   (this->mStep[0] * this->mEleSize[0] +
                    float(CorSize * (this->mStep[0] + this->mBlock[1])) / this->mKBlock +
                    this->mBlock[1] * this->mEleSize[1]));
+    if (rawk < this->mKBlock) {
+      rawk = static_cast<int>((valid_total - this->mBlock[0] * this->mBlock[1] * this->mEleSize[2] -
+                               1 * CorSize * (this->mStep[0] + this->mBlock[1])) /
+                              (this->mStep[0] * this->mEleSize[0] + this->mBlock[1] * this->mEleSize[1]));
+    }
     rawk = std::min(rawk, this->mSizePadded[2]);
     this->mBlock[2] = utils::padto_le(rawk, this->mStep[2]);
     if (this->mBlock[2] > this->mKBlock) {
@@ -569,9 +575,6 @@ class SchedulerKBlockS : public SchedulerBase<_GemmCore_T> {
       this->mBlock[2] = static_cast<int>(getMaxK(this->mBlock[1]));
       this->mBlock[2] = utils::padto_le(this->mBlock[2], this->mStep[2]);
       this->mBlock[2] = std::min(mKBlock, this->mBlock[2]);
-      auto tmp = utils::updiv(mKBlock, this->mBlock[2]);
-      while (mKBlock % tmp != 0) tmp++;  // TODO(Yu) optimize
-      this->mBlock[2] = utils::downdiv(mKBlock, tmp);
     }
   }
 
diff --git a/bestla/jblas/kernel_avx2.h b/bestla/jblas/kernel_avx2.h
index 5555e47a0..1e9fdf287 100644
--- a/bestla/jblas/kernel_avx2.h
+++ b/bestla/jblas/kernel_avx2.h
@@ -412,10 +412,14 @@ inline JBLAS_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int
     for (; j < align_col; j += 8) quant();
     for (; j < col; j++) {
       auto fp_v = ref::f8_to_fp32(srcptr[i * ld_src + j], src_f8_type);
-      if constexpr (std::is_same_v<_S_T, utils::f8>) {
-        dstptr[i * ld_dst + j] = fp_v * std::pow(2, sptr[j / _PACK_ROW].x);
-      } else if constexpr (std::is_same_v<_S_T, float>) {
-        dstptr[i * ld_dst + j] = fp_v * sptr[j / _PACK_ROW];
+      if constexpr (WITH_SCALE) {
+        if constexpr (std::is_same_v<_S_T, utils::f8>) {
+          dstptr[i * ld_dst + j] = fp_v * std::pow(2, sptr[j / _PACK_ROW].x);
+        } else if constexpr (std::is_same_v<_S_T, float>) {
+          dstptr[i * ld_dst + j] = fp_v * sptr[j / _PACK_ROW];
+        }
+      } else {
+        dstptr[i * ld_dst + j] = fp_v;
       }
     }
   }
@@ -636,6 +640,14 @@ static inline JBLAS_CODE decompress_kblock_bit4_packrow1(
         vzps[iv] = _mm256_cvtepi8_epi32(tmp);
       }
     }
+    auto rowre = row - irow;
+    int rowpad4 = utils::padto_le(rowre, UnrollRow) + irow;
+    for (; irow < rowpad4; irow += UnrollRow) {
+      for (int iter16 = 0; iter16 < Loop16; iter16++)
+        pad_bit4_16(tmpbuf + iter16 * 16, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2 + 8 * iter16));
+      for (int iterr = 0; iterr < UnrollRow; iterr++)
+        dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * _NCOL, vscales, vzps);
+    }
     for (; irow < row; irow++) {
       if constexpr (_NCOL == 24) {
         pad_bit4_16(tmpbuf, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2));
diff --git a/bestla/jblas/kernel_avx512f.h b/bestla/jblas/kernel_avx512f.h
index d0ad1aadd..8da3aafd6 100644
--- a/bestla/jblas/kernel_avx512f.h
+++ b/bestla/jblas/kernel_avx512f.h
@@ -321,13 +321,28 @@ static inline JBLAS_CODE decompress_kblock_bit4_packrow1(utils::bit4x2* srcptr,
         vzps[iv] = _mm512_cvtepi8_epi32(tmp);
       }
     }
-  }
-  for (; irow < row; irow++) {
-    pad_bit4(tmpbuf, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2), zmm_mask, LoadMask48);
-    if constexpr (_IS_SYM) {
-      dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, nullptr);
-    } else {
-      dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, vzps);
+    auto rowre = row - irow;
+    int rowpad4 = utils::padto_le(rowre, UnrollRow) + irow;
+    for (; irow < rowpad4; irow += UnrollRow) {
+      for (int iter64 = 0; iter64 < Loop64; iter64++) {
+        pad_bit4(tmpbuf + iter64 * 64, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2 + 32 * iter64), zmm_mask,
+                 LoadMask64);
+      }
+      for (int iterr = 0; iterr < UnrollRow; iterr++) {
+        if constexpr (_IS_SYM) {
+          dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * ColTile, vscales, nullptr);
+        } else {
+          dequantize(dstptr + (irow + iterr) * ld_dst, tmpbuf + iterr * ColTile, vscales, vzps);
+        }
+      }
+    }
+    for (; irow < row; irow++) {
+      pad_bit4(tmpbuf, reinterpret_cast<int8_t*>(srcptr + irow * ld_src / 2), zmm_mask, LoadMask48);
+      if constexpr (_IS_SYM) {
+        dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, nullptr);
+      } else {
+        dequantize(dstptr + irow * ld_dst, tmpbuf, vscales, vzps);
+      }
     }
   }
   return JblasSuccess;
@@ -565,7 +580,7 @@ inline JBLAS_CODE decompress_kblock_f8_fp(utils::f8* srcptr, _DST_T* dstptr, int
     auto quant = [&](__mmask16 mask) {
       __m128i f8_src;
       auto sign_revert =
-          _mm512_cvtepi8_epi32(_mm_mask_loadu_epi8(f8_src, mask, reinterpret_cast<__m128i*>(srcptr + i * ld_src + j)));
+          _mm512_cvtepi8_epi32(_mm_maskz_loadu_epi8(mask, reinterpret_cast<__m128i*>(srcptr + i * ld_src + j)));
       auto e_revert = sign_revert;
       auto mantissa_revert = sign_revert;
       sign_revert = _mm512_slli_epi32(sign_revert, 24);
@@ -888,10 +903,10 @@ inline void f32_f4_quantize_4x16(const float* srcptr, int8_t* dstptr, int ld_src
       zmm2 = _mm512_add_ps(zmm2, zmm_zp);
       zmm3 = _mm512_add_ps(zmm3, zmm_zp);
     } else {
-      mask4 = _mm512_cmplt_ps_mask(zmm0, zmm_v0);
-      mask5 = _mm512_cmplt_ps_mask(zmm1, zmm_v0);
-      mask6 = _mm512_cmplt_ps_mask(zmm2, zmm_v0);
-      mask7 = _mm512_cmplt_ps_mask(zmm3, zmm_v0);
+      mask4 = _mm512_cmp_ps_mask(zmm0, zmm_v0, 1);
+      mask5 = _mm512_cmp_ps_mask(zmm1, zmm_v0, 1);
+      mask6 = _mm512_cmp_ps_mask(zmm2, zmm_v0, 1);
+      mask7 = _mm512_cmp_ps_mask(zmm3, zmm_v0, 1);
 
       zmm0 = _mm512_abs_ps(zmm0);
       zmm1 = _mm512_abs_ps(zmm1);
@@ -908,10 +923,10 @@ inline void f32_f4_quantize_4x16(const float* srcptr, int8_t* dstptr, int ld_src
       zmm5 = _mm512_sub_ps(zmm1, sub_v);
       zmm6 = _mm512_sub_ps(zmm2, sub_v);
      zmm7 = _mm512_sub_ps(zmm3, sub_v);
-      mask0 = _mm512_cmple_ps_mask(zmm4, zmm_v0);
-      mask1 = _mm512_cmple_ps_mask(zmm5, zmm_v0);
-      mask2 = _mm512_cmple_ps_mask(zmm6, zmm_v0);
-      mask3 = _mm512_cmple_ps_mask(zmm7, zmm_v0);
+      mask0 = _mm512_cmp_ps_mask(zmm4, zmm_v0, 2);
+      mask1 = _mm512_cmp_ps_mask(zmm5, zmm_v0, 2);
+      mask2 = _mm512_cmp_ps_mask(zmm6, zmm_v0, 2);
+      mask3 = _mm512_cmp_ps_mask(zmm7, zmm_v0, 2);
       xmm0 = _mm_mask_blend_epi8(mask0, xmm0, _mm_loadu_si128(reinterpret_cast<__m128i*>(broadcast_f4_v + i * 16)));
       xmm1 = _mm_mask_blend_epi8(mask1, xmm1, _mm_loadu_si128(reinterpret_cast<__m128i*>(broadcast_f4_v + i * 16)));
       xmm2 = _mm_mask_blend_epi8(mask2, xmm2, _mm_loadu_si128(reinterpret_cast<__m128i*>(broadcast_f4_v + i * 16)));
@@ -949,7 +964,7 @@ inline void f32_f4_quantize_1x16(const float* srcptr, int8_t* dstptr, int ld_src
       auto zp = _mm512_set1_ps(0.8480964004993439f);
       zmm0 = _mm512_add_ps(zmm0, zp);
     } else {
-      mask1 = _mm512_cmplt_ps_mask(zmm0, zmm_v0);
+      mask1 = _mm512_cmp_ps_mask(zmm0, zmm_v0, 1);
       zmm0 = _mm512_abs_ps(zmm0);
     }
     constexpr int loop_num = F4_T == JBLAS_DTYPE::F4_NF4 ? 16 : 8;
@@ -959,7 +974,7 @@ inline void f32_f4_quantize_1x16(const float* srcptr, int8_t* dstptr, int ld_src
     if constexpr (F4_T == JBLAS_DTYPE::F4_BNB) sub_v = _mm512_set1_ps(F4_BNB_quant_sub_helper[i]);
     if constexpr (F4_T == JBLAS_DTYPE::F4_E2M1) sub_v = _mm512_set1_ps(F4_E2M1_quant_sub_helper[i]);
     zmm1 = _mm512_sub_ps(zmm0, sub_v);
-    mask0 = _mm512_cmple_ps_mask(zmm1, zmm_v0);
+    mask0 = _mm512_cmp_ps_mask(zmm1, zmm_v0, 2);
     xmm0 = _mm_mask_blend_epi8(mask0, xmm0, _mm_loadu_si128(reinterpret_cast<__m128i*>(broadcast_f4_v + i * 16)));
     zmm0 = _mm512_mask_add_ps(zmm0, mask0, zmm0, avoid_double_cmp);
   }
diff --git a/bestla/jblas/kernel_ref.h b/bestla/jblas/kernel_ref.h
index 1b0710f36..1e0ddccda 100644
--- a/bestla/jblas/kernel_ref.h
+++ b/bestla/jblas/kernel_ref.h
@@ -230,25 +230,47 @@ inline void convert_s4_s8_8(int8_t* dstptr, int8_t* srcptr) {
   dstptr[7] = tmp;
 }
 
+inline void convert_s4_s8_8_lowbits(int8_t* dstptr, int8_t* srcptr) {
+  auto src32 = *reinterpret_cast<uint32_t*>(srcptr);
+  auto tmp = static_cast<int8_t>(src32 & 0xf);
+  dstptr[0] = static_cast<int8_t>(tmp);
+  tmp = static_cast<int8_t>(src32 & 0xf0) >> 4;
+  dstptr[1] = static_cast<int8_t>(tmp);
+  tmp = static_cast<int8_t>((src32 & 0xf00) >> 8);
+  dstptr[2] = static_cast<int8_t>(tmp);
+  tmp = static_cast<int8_t>((src32 & 0xf000) >> 12);
+  dstptr[3] = static_cast<int8_t>(tmp);
+  tmp = static_cast<int8_t>((src32 & 0xf0000) >> 16);
+  dstptr[4] = static_cast<int8_t>(tmp);
+  tmp = static_cast<int8_t>((src32 & 0xf00000) >> 20);
+  dstptr[5] = static_cast<int8_t>(tmp);
+  tmp = static_cast<int8_t>((src32 & 0xf000000) >> 24);
+  dstptr[6] = static_cast<int8_t>(tmp);
+  tmp = static_cast<int8_t>((src32 & 0xf0000000) >> 28);
+  dstptr[7] = static_cast<int8_t>(tmp);
+}
+
 template <>
 inline void convert_s4_s8_8<JBLAS_DTYPE::S4_FULLRANGE>(int8_t* dstptr, int8_t* srcptr) {
-  auto src32 = *reinterpret_cast<uint32_t*>(srcptr);
-  auto tmp = static_cast<int8_t>(src32 & 0xf);
-  dstptr[0] = tmp - 8;
-  tmp = static_cast<int8_t>(src32 & 0xf0) >> 4;
-  dstptr[1] = tmp - 8;
-  tmp = static_cast<int8_t>((src32 & 0xf00) >> 8);
-  dstptr[2] = tmp - 8;
-  tmp = static_cast<int8_t>((src32 & 0xf000) >> 12);
-  dstptr[3] = tmp - 8;
-  tmp = static_cast<int8_t>((src32 & 0xf0000) >> 16);
-  dstptr[4] = tmp - 8;
-  tmp = static_cast<int8_t>((src32 & 0xf00000) >> 20);
-  dstptr[5] = tmp - 8;
-  tmp = static_cast<int8_t>((src32 & 0xf000000) >> 24);
-  dstptr[6] = tmp - 8;
-  tmp = static_cast<int8_t>((src32 & 0xf0000000) >> 28);
-  dstptr[7] = tmp - 8;
+  convert_s4_s8_8_lowbits(dstptr, srcptr);
+  for (size_t i = 0; i < 8; i++) {
+    dstptr[i] -= 8;
+  }
+}
+
+template <>
+inline void convert_s4_s8_8<JBLAS_DTYPE::F4_BNB>(int8_t* dstptr, int8_t* srcptr) {
+  convert_s4_s8_8_lowbits(dstptr, srcptr);
+}
+
+template <>
+inline void convert_s4_s8_8<JBLAS_DTYPE::F4_NF4>(int8_t* dstptr, int8_t* srcptr) {
+  convert_s4_s8_8_lowbits(dstptr, srcptr);
+}
+
+template <>
+inline void convert_s4_s8_8<JBLAS_DTYPE::F4_E2M1>(int8_t* dstptr, int8_t* srcptr) {
+  convert_s4_s8_8_lowbits(dstptr, srcptr);
 }
 
 template <JBLAS_DTYPE S4_T>
diff --git a/neural_speed/cmake/Common.cmake b/neural_speed/cmake/Common.cmake
index e10412aa6..d3e266ce6 100644
--- a/neural_speed/cmake/Common.cmake
+++ b/neural_speed/cmake/Common.cmake
@@ -36,9 +36,25 @@ function(add_executable_w_warning TARGET)
     warning_check(${TARGET})
 endfunction()
 
-function(add_library_w_warning TARGET)
-    add_library(${TARGET} STATIC ${ARGN})
+function(add_library_w_warning_ TARGET)
+    add_library(${TARGET} ${ARGN})
     set_target_properties(${TARGET} PROPERTIES C_STANDARD 11 C_STANDARD_REQUIRED ON C_EXTENSIONS OFF)
     set_target_properties(${TARGET} PROPERTIES CXX_STANDARD 11 CXX_STANDARD_REQUIRED ON CXX_EXTENSIONS OFF)
     warning_check(${TARGET})
 endfunction()
+
+function(add_library_w_warning TARGET)
+    add_library_w_warning_(${TARGET} STATIC ${ARGN})
+endfunction()
+
+function(add_shared_library_w_warning TARGET)
+    add_library_w_warning_(${TARGET} SHARED ${ARGN})
+endfunction()
+
+function(add_shareable_library_w_warning TARGET)
+    if (BUILD_SHARED_LIBS)
+        add_library_w_warning_(${TARGET} SHARED ${ARGN})
+    else()
+        add_library_w_warning_(${TARGET} STATIC ${ARGN})
+    endif()
+endfunction()
diff --git a/neural_speed/cmake/ISA.cmake b/neural_speed/cmake/ISA.cmake
index 63570940d..ab477da1f 100644
--- a/neural_speed/cmake/ISA.cmake
+++ b/neural_speed/cmake/ISA.cmake
@@ -13,6 +13,9 @@
 # limitations under the License.
 
 if (MSVC)
+    if(NE_F16C)
+        add_compile_definitions(__F16C__)
+    endif()
     if (NE_AVX512)
         add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX512>)
         add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX512>)
diff --git a/neural_speed/core/CMakeLists.txt b/neural_speed/core/CMakeLists.txt
index b77e8b56d..bcf34a9ca 100644
--- a/neural_speed/core/CMakeLists.txt
+++ b/neural_speed/core/CMakeLists.txt
@@ -16,7 +16,7 @@ find_package(Threads REQUIRED)
 file(GLOB layers_srcs "layers/*.cpp")
 set(sources ne_layers.c ${layers_srcs})
 
-add_library_w_warning(ne_layers "${sources}")
+add_shareable_library_w_warning(ne_layers "${sources}")
 
 target_include_directories(ne_layers PUBLIC .)
 target_compile_features(ne_layers PUBLIC c_std_11)  # don't bump
diff --git a/neural_speed/scripts/convert_mistral.py b/neural_speed/scripts/convert_mistral.py
index a93b0aa70..7bd6cfca0 100644
--- a/neural_speed/scripts/convert_mistral.py
+++ b/neural_speed/scripts/convert_mistral.py
@@ -855,8 +855,8 @@ def lazy_load_torch_file(outer_fp: IO[bytes], path: Path) -> ModelPlus:
     return ModelPlus(model=as_dict, paths=[path], format='torch', vocab=None)
 
 
-SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {'F16': DT_F16, 'F32': DT_F32, 'I32': DT_I32, 'BOOL': DT_BOOL}
-
+SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {'F16': DT_F16, 'F32': DT_F32, 'I32': DT_I32, 'BOOL': DT_BOOL,
+                                               'BF16': DT_BF16}
 
 def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus:
     header_size, = struct.unpack('<Q', must_read(fp, 8))
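
Note on the kernel_ref.h hunk: convert_s4_s8_8_lowbits extracts eight 4-bit codes from one packed 32-bit word, low nibble first, and the specializations differ only in what they do afterwards -- the full-range int4 variant re-biases each code by -8, while the F4 variants keep the raw 0..15 codes as lookup indices. Below is a minimal, self-contained C++ sketch of that unpack-then-bias scheme; it is an illustration with hypothetical names (unpack_nibbles), not the jblas API, and it uses a loop where the patch unrolls all eight stores.

// nibble_unpack_demo.cpp -- illustrative only; assumes nothing from jblas.
#include <cstdint>
#include <cstdio>
#include <cstring>

// Unpack eight 4-bit fields from a packed 32-bit word, low nibble first,
// mirroring the shape of convert_s4_s8_8_lowbits in the diff above.
static void unpack_nibbles(int8_t* dst, const uint8_t* src) {
  uint32_t src32;
  std::memcpy(&src32, src, sizeof(src32));  // aliasing-safe 4-byte load
  for (int i = 0; i < 8; ++i) {
    dst[i] = static_cast<int8_t>((src32 >> (4 * i)) & 0xf);
  }
}

int main() {
  const uint8_t packed[4] = {0x2f, 0x40, 0x91, 0x7c};  // nibbles: f,2,0,4,1,9,c,7
  int8_t vals[8];
  unpack_nibbles(vals, packed);
  for (int i = 0; i < 8; ++i) {
    // Full-range s4 stores codes 0..15; subtracting 8 recovers -8..7.
    std::printf("nibble %d: code=%2d biased=%3d\n", i, vals[i], vals[i] - 8);
  }
  return 0;
}

The AVX-512 edits are behavior-preserving in the same spirit: immediate predicate 1 in _mm512_cmp_ps_mask is _CMP_LT_OS and 2 is _CMP_LE_OS, exactly the comparisons performed by the _mm512_cmplt_ps_mask/_mm512_cmple_ps_mask wrappers they replace (wrappers that not every toolchain ships), and _mm_maskz_loadu_epi8 zero-fills the masked-off lanes instead of merging from the previously uninitialized f8_src.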