Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
[BesTLA] Improve RTN quantization accuracy of int4 and int3 (#172)
Browse files Browse the repository at this point in the history

---------

Co-authored-by: Wang,Zhe <[email protected]>
  • Loading branch information
luoyu-intel and zhewang1-intc authored Mar 18, 2024
1 parent 8728765 commit a90aea7
Show file tree
Hide file tree
Showing 20 changed files with 578 additions and 944 deletions.
1 change: 0 additions & 1 deletion bestla/bestla/bestla.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ enum class BTLA_DTYPE : uint32_t {
U8 = EleBits8 | TypeInt | SubType1,
S3_CLIP = EleBits3 | TypeInt,
S4_CLIP = EleBits4 | TypeInt,
S4_FULLRANGE = EleBits4 | TypeInt | SubType1,
F4_E2M1 = EleBits4 | TypeFloat,
F4_BNB = EleBits4 | TypeFloat | SubType1,
F4_NF4 = EleBits4 | TypeFloat | SubType2,
Expand Down
6 changes: 3 additions & 3 deletions bestla/bestla/bestla_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,8 @@ class CpuDevice {
mHybrid = false;
}
}
numcores = P_core.size() + E_core.size();
numthreads = P_core.size() + E_core.size() + SMT_core.size();
numcores = static_cast<int>(P_core.size() + E_core.size());
numthreads = static_cast<int>(P_core.size() + E_core.size() + SMT_core.size());

{
// set PE
Expand Down Expand Up @@ -515,7 +515,7 @@ class CpuRuntime {
} else {
mL1Cache_P = mL1Cache;
mL2Cache_P = mL2Cache;
P_core_num = _cd->getPcoreNum();
P_core_num = static_cast<int>(_cd->getPcoreNum());
E_core_num = thread - P_core_num;
}
mL1Cache_E = _cd->getL1CacheSize_E();
Expand Down
36 changes: 19 additions & 17 deletions bestla/bestla/bestla_parallel.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,23 +137,25 @@ class StdThreading : public IThreading {
}

inline void sync(int tidx, int idx = 0) override {
flag[idx].fetch_sub(1);
if (cr->mHybrid) {
Timer_T tm;
tm.start();
while (true) {
if (flag[idx].load() == 0)
break;
else
_mm_pause();
}
thread_time[tidx] -= int(tm.stop());
} else {
while (true) {
if (flag[idx].load() == 0)
break;
else
_mm_pause();
if (mThreadNum > 1) {
flag[idx].fetch_sub(1);
if (cr->mHybrid) {
Timer_T tm;
tm.start();
while (true) {
if (flag[idx].load() == 0)
break;
else
_mm_pause();
}
thread_time[tidx] -= int(tm.stop());
} else {
while (true) {
if (flag[idx].load() == 0)
break;
else
_mm_pause();
}
}
}
}
Expand Down
42 changes: 7 additions & 35 deletions bestla/bestla/bestla_prologue_b.h
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,7 @@ class WeightKBlockNInteger {
auto wptr = _param.packedW;
if (wptr->mDType == BTLA_DTYPE::S8) {
return getQ8Weight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize);
} else if (wptr->mDType == BTLA_DTYPE::S4_CLIP || wptr->mDType == BTLA_DTYPE::S4_FULLRANGE) {
} else if (wptr->mDType == BTLA_DTYPE::S4_CLIP) {
return getQ4Weight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize);
} else if (wptr->mDType == BTLA_DTYPE::S3_CLIP) {
return getQ3Weight(dstptr, dststep, k_size, n_size, k_offset, n_offset, _param, tmpcache, cachesize);
Expand Down Expand Up @@ -674,7 +674,7 @@ class WeightKBlockNInteger {
kernel::wrapper::Dq8GetScale::template forward<ISA_T>(
aptr + internal_k_offset * wptr->CStep() + n_offset, *dstptr, utils::updiv(k_size, wptr->mBlockSize), n_size,
internal_k_offset * wptr->mN + n_offset, wptr->mDqBlockSize, dq_offset_idx, wptr->DQPtr<float>(),
wptr->CStep(), n_size, false);
wptr->CStep(), n_size, false, wptr->mN);
}
return BTLA_CODE::Success;
}
Expand Down Expand Up @@ -713,11 +713,6 @@ class WeightKBlockNInteger {
wptr->template WPtr<utils::int4x2>() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 +
i * KPad / 2,
*dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize);
} else if (wptr->mDType == BTLA_DTYPE::S4_FULLRANGE) {
kernel::wrapper::DecompressKBlockS4S8Fp<T>::template forward<ISA_T, BTLA_DTYPE::S4_FULLRANGE>(
wptr->template WPtr<utils::int4x2>() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 +
i * KPad / 2,
*dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, tmpcache, cachesize);
} else if (wptr->mDType == BTLA_DTYPE::S8) {
kernel::wrapper::DecompressKBlockS8S8Fp<T>::template forward<ISA_T>(
wptr->template WPtr<int8_t>() + n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad,
Expand Down Expand Up @@ -761,14 +756,6 @@ class WeightKBlockNInteger {
*dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr,
zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW,
wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize);
} else if (wptr->mDType == BTLA_DTYPE::S4_FULLRANGE) {
kernel::wrapper::DecompressKBlockS4Fp<_T, _GemmCore_T::PACK_ROW>::template forward<ISA_T, float,
BTLA_DTYPE::S4_FULLRANGE>(
wptr->template WPtr<utils::int4x2>() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 +
i * KPad / 2,
*dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr,
zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW,
wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize);
} else if (wptr->mDType == BTLA_DTYPE::S8) {
kernel::wrapper::DecompressKBlockS8Fp<_T, _GemmCore_T::PACK_ROW>::template forward<ISA_T, float>(
wptr->template WPtr<int8_t>() + n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad,
Expand Down Expand Up @@ -802,14 +789,6 @@ class WeightKBlockNInteger {
*dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr,
zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW,
wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize);
} else if (wptr->mDType == BTLA_DTYPE::S4_FULLRANGE) {
kernel::wrapper::DecompressKBlockS4Fp<_T, _GemmCore_T::PACK_ROW>::template forward<ISA_T, utils::bf16,
BTLA_DTYPE::S4_FULLRANGE>(
wptr->template WPtr<utils::int4x2>() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2 +
i * KPad / 2,
*dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize, sptr,
zptr != nullptr ? zptr + n_offset + i : nullptr, k_offset / _GemmCore_T::PACK_ROW,
wptr->mBlockSize / _GemmCore_T::PACK_ROW, NPad, tmpcache, cachesize);
} else if (wptr->mDType == BTLA_DTYPE::S8) {
kernel::wrapper::DecompressKBlockS8Fp<_T, _GemmCore_T::PACK_ROW>::template forward<ISA_T, utils::bf16>(
wptr->template WPtr<int8_t>() + n_offset * KPad + k_offset * _GemmCore_T::NTILE + i * KPad,
Expand Down Expand Up @@ -873,14 +852,10 @@ class WeightKBlockNInteger {
auto KPad = wptr->mKPad;
auto bptr = wptr->template WPtr<utils::int4x2>() + n_offset * KPad / 2 + k_offset * _GemmCore_T::NTILE / 2;
int constexpr ColSize = _GemmCore_T::NTILE * _GemmCore_T::PACK_ROW;
assert(wptr->mDType == BTLA_DTYPE::S4_CLIP);
for (int i = 0; i < n_size; i += _GemmCore_T::NTILE) {
if (wptr->mDType == BTLA_DTYPE::S4_CLIP) {
kernel::wrapper::DecompressKBlockS4S8::template forward<ISA_T, BTLA_DTYPE::S4_CLIP>(
bptr + i * KPad / 2, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize);
} else if (wptr->mDType == BTLA_DTYPE::S4_FULLRANGE) {
kernel::wrapper::DecompressKBlockS4S8::template forward<ISA_T, BTLA_DTYPE::S4_FULLRANGE>(
bptr + i * KPad / 2, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize);
}
kernel::wrapper::DecompressKBlockS4S8::template forward<ISA_T, BTLA_DTYPE::S4_CLIP>(
bptr + i * KPad / 2, *dstptr + i * k_size, k_size / _GemmCore_T::PACK_ROW, ColSize, ColSize, ColSize);
}
*dststep = k_size;
return BTLA_CODE::Success;
Expand Down Expand Up @@ -916,9 +891,6 @@ class WeightKBlockNInteger {
if (quant_dtype == BTLA_DTYPE::S8) {
kernel::wrapper::QuantizeSignIntRowBlock::forward<ISA_T, BTLA_DTYPE::S8>(srcptr, dstptr, row, col, ld_src, ld_dst,
scales, zero_points, ptr->mBlockSize);
} else if (quant_dtype == BTLA_DTYPE::S4_FULLRANGE) {
kernel::wrapper::QuantizeSignIntRowBlock::forward<ISA_T, BTLA_DTYPE::S4_FULLRANGE>(
srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, ptr->mBlockSize);
} else if (quant_dtype == BTLA_DTYPE::S4_CLIP) {
kernel::wrapper::QuantizeSignIntRowBlock::forward<ISA_T, BTLA_DTYPE::S4_CLIP>(
srcptr, dstptr, row, col, ld_src, ld_dst, scales, zero_points, ptr->mBlockSize);
Expand All @@ -932,7 +904,7 @@ class WeightKBlockNInteger {

static inline BTLA_CODE doCompress(const int8_t* srcptr, void* dstptr, int row, int col, int ld_src, int ld_dst,
BTLA_DTYPE quant_dtype) {
if (quant_dtype == BTLA_DTYPE::S4_CLIP || quant_dtype == BTLA_DTYPE::S4_FULLRANGE) {
if (quant_dtype == BTLA_DTYPE::S4_CLIP) {
return kernel::wrapper::CompressS8S4::forward<ISA_T>(srcptr, reinterpret_cast<utils::int4x2*>(dstptr), row, col,
ld_src, ld_dst);
} else if (quant_dtype == BTLA_DTYPE::F4_BNB || quant_dtype == BTLA_DTYPE::F4_NF4 ||
Expand Down Expand Up @@ -1051,7 +1023,7 @@ class WeightKBlockNFloat : public WeightKBlockNInteger<_GemmCore_T, ISA_T> {
auto internal_n_offset = n_offset + i;
auto internal_k_offset = k_offset / _GemmCore_T::PACK_ROW;
auto internal_kblock = wptr->mBlockSize / _GemmCore_T::PACK_ROW;
auto dq_offset_idx = wptr->mCorrection.mDQCorrectionBuf.mBufSize / sizeof(float) - 1;
auto dq_offset_idx = static_cast<int>(wptr->mCorrection.mDQCorrectionBuf.mBufSize / sizeof(float) - 1);
if (wptr->mDType == BTLA_DTYPE::F4_NF4) {
kernel::wrapper::DecompressDqKBlockF4Fp<_DST_T, _GemmCore_T::PACK_ROW>::template forward<ISA_T,
BTLA_DTYPE::F4_NF4>(
Expand Down
54 changes: 31 additions & 23 deletions bestla/bestla/bestla_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,7 @@
#define CompileAMXINT8() (CompileAMX())
#endif

#if CompileBF16() || CompileFP16()
#include <immintrin.h>
#endif

namespace bestla {
namespace utils {
Expand Down Expand Up @@ -157,6 +155,18 @@ struct f8 {
x = v;
return *this;
}

inline float tofloat() const {
int32_t r = x + 127;
uint32_t tmp = bit_cast<uint32_t, int32_t>(r & 0xff);
tmp <<= 23;
return bit_cast<float, uint32_t>(tmp);
}

inline float mul(float src) const {
auto scale = tofloat();
return src * scale;
}
};

struct fp16 {
Expand Down Expand Up @@ -326,8 +336,6 @@ inline const char* bestla_dtype_str(BTLA_DTYPE dtype) {
return "unsigned_int8";
case BTLA_DTYPE::S4_CLIP:
return "int4_clip";
case BTLA_DTYPE::S4_FULLRANGE:
return "int4_fullrange";
case BTLA_DTYPE::F4_E2M1:
return "fp4_e2m1";
case BTLA_DTYPE::F4_BNB:
Expand Down Expand Up @@ -697,13 +705,13 @@ inline bool isFastExp() {
}
} // namespace utils

static float fp4_bnb_dequant_fp32_LUT[] = {
static float fp4_bnb_dequant_fp32_LUT alignas(64)[] = {
0.00000000f, 5.208333333e-03f, 0.66666667f, 1.00000000f, 0.33333333f,
0.50000000f, 0.16666667f, 0.25000000f, -1.f * 0.00000000f, -1.f * 5.208333333e-03f,
-1.f * 0.66666667f, -1.f * 1.00000000f, -1.f * 0.33333333f, -1.f * 0.50000000f, -1.f * 0.16666667f,
-1.f * 0.25000000f};

static float fp4_e2m1_dequant_fp32_LUT[] = {
static float fp4_e2m1_dequant_fp32_LUT alignas(64)[] = {
0.f,
0.010416666666666666f,
0.16666666666666666f,
Expand All @@ -722,27 +730,27 @@ static float fp4_e2m1_dequant_fp32_LUT[] = {
-1.f * 1.f,
};

static float nf4_dequant_fp32_LUT[] = {0.f,
-0.6961928009986877f,
-0.5250730514526367f,
-0.39491748809814453f,
-0.28444138169288635f,
-0.18477343022823334f,
-0.09105003625154495f,
-1.f,
0.07958029955625534f,
0.16093020141124725f,
0.24611230194568634f,
0.33791524171829224f,
0.44070982933044434f,
0.5626170039176941f,
0.7229568362236023f,
1.0f};
static float nf4_dequant_fp32_LUT alignas(64)[] = {0.f,
-0.6961928009986877f,
-0.5250730514526367f,
-0.39491748809814453f,
-0.28444138169288635f,
-0.18477343022823334f,
-0.09105003625154495f,
-1.f,
0.07958029955625534f,
0.16093020141124725f,
0.24611230194568634f,
0.33791524171829224f,
0.44070982933044434f,
0.5626170039176941f,
0.7229568362236023f,
1.0f};

// 8bit dynamic-tree-quantization map from bitsandbytes double-quant implementation.
// For more details pls refer
// (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561]
static float dq8_bnb_LUT[] = {
static float dq8_bnb_LUT alignas(64)[] = {
-0.99297f, -0.97891f, -0.96484f, -0.95078f, -0.93672f, -0.92266f, -0.90859f, -0.89453f, -0.88047f, -0.86641f,
-0.85234f, -0.83828f, -0.82422f, -0.81016f, -0.79609f, -0.78203f, -0.76797f, -0.75391f, -0.73984f, -0.72578f,
-0.71172f, -0.69766f, -0.68359f, -0.66953f, -0.65547f, -0.64141f, -0.62734f, -0.61328f, -0.59922f, -0.58516f,
Expand Down
Loading

0 comments on commit a90aea7

Please sign in to comment.