diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 22456990792c3..a03fc695eabfa 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -294,12 +294,27 @@ def write_tensors(self): )) if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32: - if self.ftype == gguf.LlamaFileType.MOSTLY_BF16: + if self.ftype == gguf.LlamaFileType.MOSTLY_Q1_3 and not any( + self.match_model_tensor_name(new_name, key, None) + for key in [ + gguf.MODEL_TENSOR.TOKEN_EMBD, + gguf.MODEL_TENSOR.OUTPUT, + ] + ): + data = gguf.quantize_q1_3(data) + assert data.dtype == np.uint8 + data_qtype = gguf.GGMLQuantizationType.Q1_3 + + elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16: data = gguf.quantize_bf16(data) assert data.dtype == np.int16 data_qtype = gguf.GGMLQuantizationType.BF16 - elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0 and gguf.can_quantize_to_q8_0(data): + elif ( + self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0 + or self.ftype == gguf.LlamaFileType.MOSTLY_Q1_3 + and gguf.can_quantize_to_q8_0(data) + ): data = gguf.quantize_q8_0(data) assert data.dtype == np.uint8 data_qtype = gguf.GGMLQuantizationType.Q8_0 @@ -1401,6 +1416,12 @@ def write_tensors(self): class BitnetModel(Model): model_arch = gguf.MODEL_ARCH.BITNET + def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool, model_name: str | None): + if ftype == gguf.LlamaFileType.GUESSED: + ftype = gguf.LlamaFileType.MOSTLY_Q1_3 + + super().__init__(dir_model, ftype, fname_out, is_big_endian, use_temp_file, eager, model_name) + def set_vocab(self): self._set_vocab_sentencepiece() diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 05df330c0846f..d35a79b8818fa 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -26,6 +26,7 @@ static const std::vector QUANT_OPTIONS = { { "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M, " 2.7 bpw quantization", }, { "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", }, { "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", }, + { "Q1_3", LLAMA_FTYPE_MOSTLY_Q1_3, " 1.63 bpw for BitNet 1.58b" }, { "Q2_2", LLAMA_FTYPE_MOSTLY_Q2_2, " 2 bpw quantization", }, { "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", }, { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", }, diff --git a/ggml-common.h b/ggml-common.h index a1a8246656cca..fd5d8a90a874b 100644 --- a/ggml-common.h +++ b/ggml-common.h @@ -137,6 +137,14 @@ typedef sycl::half2 ggml_half2; #endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP +// 1.625 bpw for BitNet 1.58b models +#define QK1_3 64 +typedef struct { + uint8_t q[(QK1_3 - 4*QK1_3/64)/5]; // 5 elements per byte (3^5 = 243 < 256) + uint8_t qs[QK1_3/64]; // 4 elements per byte +} block_q1_3; +static_assert(sizeof(block_q1_3) == (QK1_3 - 4*QK1_3/64)/5 + QK1_3/64, "wrong q1_3 block size/padding"); + #define QK2_2 32 typedef struct { uint8_t qs[QK2_2 / 4]; // nibbles / quants @@ -339,6 +347,7 @@ typedef struct { } block_iq3_s; static_assert(sizeof(block_iq3_s) == sizeof(ggml_half) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding"); +// 1.5625 bpw typedef struct { ggml_half d; uint8_t qs[QK_K/8]; @@ -1095,6 +1104,41 @@ GGML_TABLE_BEGIN(uint32_t, q22_grid, 256) 0x00ffffff, 0x01ffffff, 0x00ffffff, 0xffffffff, GGML_TABLE_END() +GGML_TABLE_BEGIN(uint32_t, q1_3_grid, 256) + 0xffffffff, 0xffffffff, 0xffffff00, 0xffffff01, 0xffff00ff, 0xffff0000, 0xffff0001, 0xffff01ff, + 0xffff0100, 0xffff0101, 0xff00ffff, 0xff00ff00, 0xff00ff01, 0xff0000ff, 0xff000000, 0xff000001, + 0xff0001ff, 0xff000100, 0xff000101, 0xff01ffff, 0xff01ffff, 0xff01ff00, 0xff01ff01, 0xff0100ff, + 0xff010000, 0xff010001, 0xff0101ff, 0xff010100, 0xff010101, 0x00ffffff, 0x00ffff00, 0x00ffff01, + 0x00ff00ff, 0x00ff0000, 0x00ff0001, 0x00ff01ff, 0x00ff0100, 0x00ff0101, 0x0000ffff, 0x0000ff00, + 0x0000ff00, 0x0000ff01, 0x000000ff, 0x00000000, 0x00000001, 0x000001ff, 0x00000100, 0x00000101, + 0x0001ffff, 0x0001ff00, 0x0001ff01, 0x000100ff, 0x00010000, 0x00010001, 0x000101ff, 0x00010100, + 0x00010101, 0x01ffffff, 0x01ffff00, 0x01ffff01, 0x01ffff01, 0x01ff00ff, 0x01ff0000, 0x01ff0001, + 0x01ff01ff, 0x01ff0100, 0x01ff0101, 0x0100ffff, 0x0100ff00, 0x0100ff01, 0x010000ff, 0x01000000, + 0x01000001, 0x010001ff, 0x01000100, 0x01000101, 0x0101ffff, 0x0101ff00, 0x0101ff01, 0x0101ff01, + 0x010100ff, 0x01010000, 0x01010001, 0x010101ff, 0x01010100, 0x01010101, 0xffffffff, 0xffffff00, + 0xffffff01, 0xffff00ff, 0xffff0000, 0xffff0001, 0xffff01ff, 0xffff0100, 0xffff0101, 0xff00ffff, + 0xff00ff00, 0xff00ff01, 0xff0000ff, 0xff0000ff, 0xff000000, 0xff000001, 0xff0001ff, 0xff000100, + 0xff000101, 0xff01ffff, 0xff01ff00, 0xff01ff01, 0xff0100ff, 0xff010000, 0xff010001, 0xff0101ff, + 0xff010100, 0xff010101, 0x00ffffff, 0x00ffff00, 0x00ffff01, 0x00ff00ff, 0x00ff0000, 0x00ff0000, + 0x00ff0001, 0x00ff01ff, 0x00ff0100, 0x00ff0101, 0x0000ffff, 0x0000ff00, 0x0000ff01, 0x000000ff, + 0x00000000, 0x00000001, 0x000001ff, 0x00000100, 0x00000101, 0x0001ffff, 0x0001ff00, 0x0001ff01, + 0x000100ff, 0x00010000, 0x00010000, 0x00010001, 0x000101ff, 0x00010100, 0x00010101, 0x01ffffff, + 0x01ffff00, 0x01ffff01, 0x01ff00ff, 0x01ff0000, 0x01ff0001, 0x01ff01ff, 0x01ff0100, 0x01ff0101, + 0x0100ffff, 0x0100ff00, 0x0100ff01, 0x010000ff, 0x01000000, 0x01000001, 0x01000001, 0x010001ff, + 0x01000100, 0x01000101, 0x0101ffff, 0x0101ff00, 0x0101ff01, 0x010100ff, 0x01010000, 0x01010001, + 0x010101ff, 0x01010100, 0x01010101, 0xffffffff, 0xffffff00, 0xffffff01, 0xffff00ff, 0xffff0000, + 0xffff0001, 0xffff01ff, 0xffff01ff, 0xffff0100, 0xffff0101, 0xff00ffff, 0xff00ff00, 0xff00ff01, + 0xff0000ff, 0xff000000, 0xff000001, 0xff0001ff, 0xff000100, 0xff000101, 0xff01ffff, 0xff01ff00, + 0xff01ff01, 0xff0100ff, 0xff010000, 0xff010001, 0xff0101ff, 0xff0101ff, 0xff010100, 0xff010101, + 0x00ffffff, 0x00ffff00, 0x00ffff01, 0x00ff00ff, 0x00ff0000, 0x00ff0001, 0x00ff01ff, 0x00ff0100, + 0x00ff0101, 0x0000ffff, 0x0000ff00, 0x0000ff01, 0x000000ff, 0x00000000, 0x00000001, 0x000001ff, + 0x00000100, 0x00000100, 0x00000101, 0x0001ffff, 0x0001ff00, 0x0001ff01, 0x000100ff, 0x00010000, + 0x00010001, 0x000101ff, 0x00010100, 0x00010101, 0x01ffffff, 0x01ffff00, 0x01ffff01, 0x01ff00ff, + 0x01ff0000, 0x01ff0001, 0x01ff01ff, 0x01ff0100, 0x01ff0101, 0x01ff0101, 0x0100ffff, 0x0100ff00, + 0x0100ff01, 0x010000ff, 0x01000000, 0x01000001, 0x010001ff, 0x01000100, 0x01000101, 0x0101ffff, + 0x0101ff00, 0x0101ff01, 0x010100ff, 0x01010000, 0x01010001, 0x010101ff, 0x01010100, 0x01010101, +GGML_TABLE_END() + #define NGRID_IQ1S 2048 #define IQ1S_DELTA 0.125f #define IQ1M_DELTA 0.125f diff --git a/ggml-quants.c b/ggml-quants.c index f45ece1f25836..138b19fc46c98 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -1683,7 +1683,7 @@ void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int6 // ===================== Helper functions // static inline int nearest_int(float fval) { - assert(fval <= 4194303.f); + assert(fabsf(fval) <= 4194303.f); float val = fval + 12582912.f; int i; memcpy(&i, &val, sizeof(int)); return (i & 0x007fffff) - 0x00400000; @@ -3366,6 +3366,133 @@ size_t quantize_q2_2(const float * restrict src, void * restrict dst, int64_t nr return nrow * row_size; } +// ====================== 1.625 bpw (de)-quantization (BitNet 1.58b) + +void quantize_row_q1_3_reference(const float * restrict x, block_q1_3 * restrict y, int64_t k) { + assert(k % QK1_3 == 0); + const int64_t nb = k / QK1_3; + static_assert(sizeof(y->q) % 4 == 0, "bad block_q1_3.q size"); + + const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243}; + + for (int64_t i = 0; i < nb; ++i) { + uint8_t q[sizeof(y->q)] = {0}; + for (size_t j = 0; j < sizeof(y->q); ++j) { + for (size_t m = 0; m < 4; ++m) { + int xi = nearest_int(x[m]); + uint8_t xt = xi < 0 ? 0 : xi == 0 ? 1 : 2; + q[j] += xt * pow3[m]; + } + x += 4; + } + for (size_t j = 0; j < sizeof(y->q); ++j) { + int xi = nearest_int(x[j]); + uint8_t xt = xi < 0 ? 0 : xi == 0 ? 1 : 2; + q[j] += xt * pow3[4]; + q[j] = ((uint16_t)q[j] * 256) / pow3[5]; + q[j] += (uint8_t)(q[j] != 0); + y[i].q[j] = q[j]; + } + x += sizeof(y->q); + + for (size_t j = 0; j < sizeof(y->qs); ++j) { + uint8_t qb = 0; + for (size_t m = 0; m < 4; ++m) { + int xi = nearest_int(x[m]); + uint8_t xt = xi < 0 ? 0 : xi == 0 ? 1 : 2; + qb += xt * pow3[m]; + } + x += 4; + qb = ((uint16_t)qb * 256) / pow3[5]; + qb += (uint8_t)(qb != 0); + y[i].qs[j] = qb; + } + } +} + +void quantize_row_q1_3(const float * restrict x, void * restrict vy, int64_t k) { + assert(k % QK1_3 == 0); + block_q1_3 * restrict y = vy; + quantize_row_q1_3_reference(x, y, k); +} + +size_t quantize_q1_3(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + (void)quant_weights; // not used + const size_t row_size = ggml_row_size(GGML_TYPE_Q1_3, n_per_row); + quantize_row_q1_3(src, dst, (int64_t)nrow*n_per_row); + return nrow * row_size; +} + +void dequantize_row_q1_3(const block_q1_3 * restrict x, float * restrict y, int64_t k) { + assert(k % QK1_3 == 0); + const int64_t nb = k / QK1_3; + static_assert(sizeof(x->q) % 4 == 0, "bad block_q1_3.q size"); + +// #if defined(__SSE2__) +// __m128 vscale = _mm_set1_ps(scale); + +// for (int64_t i = 0; i < nb; ++i) { +// for (size_t j = 0; j < sizeof(x->q); j += 4) { +// __m128 q1 = _mm_cvtpi8_ps(_m_from_int(q1_3_grid[x[i].q[j + 0]])); +// __m128 q2 = _mm_cvtpi8_ps(_m_from_int(q1_3_grid[x[i].q[j + 1]])); +// __m128 q3 = _mm_cvtpi8_ps(_m_from_int(q1_3_grid[x[i].q[j + 2]])); +// __m128 q4 = _mm_cvtpi8_ps(_m_from_int(q1_3_grid[x[i].q[j + 3]])); +// q1 = _mm_mul_ps(q1, vscale); +// q2 = _mm_mul_ps(q2, vscale); +// q3 = _mm_mul_ps(q3, vscale); +// q4 = _mm_mul_ps(q4, vscale); + +// _mm_store_ps(y + 0, q1); +// _mm_store_ps(y + 4, q2); +// _mm_store_ps(y + 8, q3); +// _mm_store_ps(y + 12, q4); +// y += 16; +// } + +// for (size_t j = 0; j < sizeof(x->q); j += 4) { +// __m128i q5i = _mm_loadu_si32(x[i].q + j); +// q5i = _mm_cvtepi8_epi16(q5i); +// q5i = _mm_add_epi16(q5i, _mm_add_epi16(q5i, q5i)); +// q5i = _mm_srli_epi16(q5i, 8); +// q5i = _mm_sub_epi16(q5i, _mm_set1_epi16(1)); +// __m128 q5 = _mm_cvtepi32_ps(_mm_cvtepi16_epi32(q5i)); +// q5 = _mm_mul_ps(q5, vscale); + +// _mm_store_ps(y, q5); +// y += 4; +// } + +// for (size_t j = 0; j < sizeof(x->qs); ++j) { +// __m128 q = _mm_cvtpi8_ps(_m_from_int(q1_3_grid[x[i].qs[j]])); +// q = _mm_mul_ps(q, vscale); +// _mm_store_ps(y, q); +// y += 4; +// } +// } +// #else + for (int64_t i = 0; i < nb; ++i) { + for (size_t j = 0; j < sizeof(x->q); ++j) { + const int8_t * q = (const int8_t *) (q1_3_grid + x[i].q[j]); + for (int m = 0; m < 4; ++m) { + *y++ = (float) q[m]; + } + } + + for (size_t j = 0; j < sizeof(x->q); ++j) { + uint16_t q = x[i].q[j]; + *y++ = (float) ((int16_t)((q * 3) >> 8) - 1); + } + + for (size_t j = 0; j < sizeof(x->qs); ++j) { + const int8_t * q = (const int8_t *) (q1_3_grid + x[i].qs[j]); + for (int m = 0; m < 4; ++m) { + *y++ = (float) q[m]; + } + } + } +// #endif +} + // ====================== "True" 2-bit (de)-quantization void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) { @@ -3802,7 +3929,58 @@ void ggml_vec_dot_q2_2_q8_0(int n, float * restrict s, size_t bs, const void * r #if defined(__AVX2__) __m256 acc = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { + int leftovers = nb % 2; + + for (int i = 0; i < nb - leftovers; i += 2) { + + const __m256 d0 = _mm256_set1_ps( GGML_FP16_TO_FP32(y[i + 0].d) ); + const __m256 d1 = _mm256_set1_ps( GGML_FP16_TO_FP32(y[i + 1].d) ); + + // assuming two consecutive blocks are contiguous AND aligned + __m128i xq16b = _mm_load_si128((const __m128i *) (x[i].qs)); + __m256i xq16 = MM256_SET_M128I(xq16b, xq16b); + __m256i xq8l0 = _mm256_shuffle_epi8(xq16, _mm256_set_epi8(5, -1, 5, -1, 5, -1, 5, -1, + 4, -1, 4, -1, 4, -1, 4, -1, + 1, -1, 1, -1, 1, -1, 1, -1, + 0, -1, 0, -1, 0, -1, 0, -1)); + __m256i xq8h0 = _mm256_shuffle_epi8(xq16, _mm256_set_epi8(7, -1, 7, -1, 7, -1, 7, -1, + 6, -1, 6, -1, 6, -1, 6, -1, + 3, -1, 3, -1, 3, -1, 3, -1, + 2, -1, 2, -1, 2, -1, 2, -1)); + __m256i xq8l1 = _mm256_shuffle_epi8(xq16, _mm256_set_epi8(13, -1, 13, -1, 13, -1, 13, -1, + 12, -1, 12, -1, 12, -1, 12, -1, + 9, -1, 9, -1, 9, -1, 9, -1, + 8, -1, 8, -1, 8, -1, 8, -1)); + __m256i xq8h1 = _mm256_shuffle_epi8(xq16, _mm256_set_epi8(15, -1, 15, -1, 15, -1, 15, -1, + 14, -1, 14, -1, 14, -1, 14, -1, + 11, -1, 11, -1, 11, -1, 11, -1, + 10, -1, 10, -1, 10, -1, 10, -1)); + __m256i shift = _mm256_set_epi16(64, 16, 4, 1, + 64, 16, 4, 1, + 64, 16, 4, 1, + 64, 16, 4, 1); + xq8l0 = _mm256_mullo_epi16(xq8l0, shift); + xq8h0 = _mm256_mullo_epi16(xq8h0, shift); + xq8l1 = _mm256_mullo_epi16(xq8l1, shift); + xq8h1 = _mm256_mullo_epi16(xq8h1, shift); + xq8l0 = _mm256_srai_epi16(xq8l0, 14); + xq8h0 = _mm256_srai_epi16(xq8h0, 14); + xq8l1 = _mm256_srai_epi16(xq8l1, 14); + xq8h1 = _mm256_srai_epi16(xq8h1, 14); + __m256i xq8_0 = _mm256_packs_epi16(xq8l0, xq8h0); + __m256i xq8_1 = _mm256_packs_epi16(xq8l1, xq8h1); + + __m256i yq8_0 = _mm256_lddqu_si256((const __m256i *) (y[i + 0].qs)); + __m256i yq8_1 = _mm256_lddqu_si256((const __m256i *) (y[i + 1].qs)); + + const __m256 q0 = mul_sum_i8_pairs_float(xq8_0, yq8_0); + const __m256 q1 = mul_sum_i8_pairs_float(xq8_1, yq8_1); + + acc = _mm256_fmadd_ps( d0, q0, acc ); + acc = _mm256_fmadd_ps( d1, q1, acc ); + } + + for (int i = nb - leftovers; i < nb; ++i) { const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(y[i].d) ); @@ -3826,7 +4004,7 @@ void ggml_vec_dot_q2_2_q8_0(int n, float * restrict s, size_t bs, const void * r xq8h = _mm256_srai_epi16(xq8h, 14); xq8 = _mm256_packs_epi16(xq8l, xq8h); - __m256i yq8 = _mm256_lddqu_si256((const __m256i*)(y[i].qs)); + __m256i yq8 = _mm256_lddqu_si256((const __m256i *) (y[i].qs)); const __m256 q = mul_sum_i8_pairs_float(xq8, yq8); acc = _mm256_fmadd_ps( d, q, acc ); @@ -10812,6 +10990,105 @@ static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) { } #endif +void ggml_vec_dot_q1_3_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + // assumed by the code below + assert(n % QK1_3 == 0); + static_assert(QK1_3 == 2 * QK8_0, "QK1_3 must be 2 times bigger than QK8_0"); + + const block_q1_3 * restrict x = vx; + const block_q8_0 * restrict y = vy; + + const int nb = n / QK1_3; + +#if defined(__AVX2__) + __m256 accumf = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + { + __m256i x0 = _mm256_set_epi32(q1_3_grid[x[i].q[7]], q1_3_grid[x[i].q[6]], + q1_3_grid[x[i].q[5]], q1_3_grid[x[i].q[4]], + q1_3_grid[x[i].q[3]], q1_3_grid[x[i].q[2]], + q1_3_grid[x[i].q[1]], q1_3_grid[x[i].q[0]]); + __m256i y0 = _mm256_lddqu_si256((const __m256i_u *) (y[2*i].qs)); + + __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(y[2*i].d)); + + __m256 q = mul_sum_i8_pairs_float(x0, y0); + + accumf = _mm256_fmadd_ps(d, q, accumf); + } + + { + __m256i x1 = _mm256_castsi128_si256(_mm_set_epi32(q1_3_grid[x[i].q[11]], q1_3_grid[x[i].q[10]], + q1_3_grid[x[i].q[9]], q1_3_grid[x[i].q[8]])); + __m256i x2 = _mm256_cvtepu8_epi16(_mm_maskload_epi32((const int32_t *) x[i].q, _mm_set_epi32(0, -1, -1, -1))); + __m256i y1 = _mm256_lddqu_si256((const __m256i_u *) (y[2*i + 1].qs)); + + x2 = _mm256_mulhi_epu16(x2, _mm256_set1_epi16(3 << 8)); + x2 = _mm256_sub_epi16(x2, _mm256_set1_epi16(1)); + + // TODO: reduce shuffling + x2 = _mm256_packs_epi16(x2, _mm256_setzero_si256()); + x2 = _mm256_permute4x64_epi64(x2, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i x2_l = _mm_insert_epi32(_mm256_castsi256_si128(x2), q1_3_grid[x[i].qs[0]], 3); + x1 = _mm256_inserti128_si256(x1, x2_l, 1); + + __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(y[2*i + 1].d)); + + __m256 q = mul_sum_i8_pairs_float(x1, y1); + + accumf = _mm256_fmadd_ps(d, q, accumf); + } + } + + *s = hsum_float_8(accumf); +#else + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + int sum = 0; + for (int j = 0; j < 8; ++j) { + const int8_t * xj = (const int8_t *) (q1_3_grid + x[i].q[j]); + for (int k = 0; k < 4; ++k) { + sum += xj[k] * (int16_t) y[2*i].qs[4*j + k]; + } + } + + sumf += GGML_FP16_TO_FP32(y[2*i].d) * sum; + sum = 0; + + for (int j = 0; j < 4; ++j) { + const int8_t * xj = (const int8_t *) (q1_3_grid + x[i].q[8 + j]); + for (int k = 0; k < 4; ++k) { + sum += xj[k] * (int16_t) y[2*i + 1].qs[4*j + k]; + } + } + + for (size_t j = 0; j < 12; ++j) { + uint16_t xj = x[i].q[j]; + xj = (xj * 3) >> 8; + sum += ((int16_t) xj - 1) * (int16_t) y[2*i + 1].qs[16 + j]; + } + + { + const int8_t * xj = (const int8_t *) (q1_3_grid + x[i].qs[0]); + for (int k = 0; k < 4; ++k) { + sum += (int16_t) xj[k] * (int16_t) y[2*i + 1].qs[28 + k]; + } + } + + sumf += GGML_FP16_TO_FP32(y[2*i + 1].d) * sum; + } + + *s = sumf; +#endif +} + void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -14488,6 +14765,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb); } break; + case GGML_TYPE_Q1_3: case GGML_TYPE_Q2_2: case GGML_TYPE_I8: case GGML_TYPE_I16: diff --git a/ggml-quants.h b/ggml-quants.h index e159cef5f71b7..fe28132ad82bf 100644 --- a/ggml-quants.h +++ b/ggml-quants.h @@ -12,6 +12,7 @@ extern "C" { #endif // Quantization +void quantize_row_q1_3_reference(const float * GGML_RESTRICT x, block_q1_3 * GGML_RESTRICT y, int64_t k); void quantize_row_q2_2_reference(const float * GGML_RESTRICT x, block_q2_2 * GGML_RESTRICT y, int64_t k); void quantize_row_q4_0_reference(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k); void quantize_row_q4_1_reference(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k); @@ -33,6 +34,7 @@ void quantize_row_iq4_xs_reference (const float * GGML_RESTRICT x, block_iq4_xs void quantize_row_iq3_s_reference (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k); void quantize_row_iq2_s_reference (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k); +void quantize_row_q1_3(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q2_2(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -55,6 +57,7 @@ void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); // Dequantization +void dequantize_row_q1_3(const block_q1_3 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_q2_2(const block_q2_2 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); @@ -81,6 +84,7 @@ void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_ void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); // Dot product +void ggml_vec_dot_q1_3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q2_2_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); @@ -120,6 +124,7 @@ size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q1_3(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q2_2(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); diff --git a/ggml.c b/ggml.c index 303b2f5636147..49f67146a868e 100644 --- a/ggml.c +++ b/ggml.c @@ -854,6 +854,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, }, + [GGML_TYPE_Q1_3] = { + .type_name = "q1_3", + .blck_size = QK1_3, + .type_size = sizeof(block_q1_3), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q1_3, + .from_float = quantize_row_q1_3, + .from_float_reference = (ggml_from_float_t) quantize_row_q1_3_reference, + .vec_dot = ggml_vec_dot_q1_3_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, + .nrows = 1, + }, [GGML_TYPE_IQ1_S] = { .type_name = "iq1_s", .blck_size = QK_K, @@ -10144,7 +10156,16 @@ static void ggml_compute_forward_mul_f32( GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); - if (nb10 == sizeof(float)) { + if (ggml_nelements(src1) == 1) { + float scale = ((float *) src1->data)[0]; + for (int64_t ir = ith; ir < nr; ir += nth) { + if (dst->data != src0->data) { + // src0 is same shape as dst => same indices + memcpy((char *)dst->data + ir*nb1, (char *)src0->data + ir*nb01, ne0 * sizeof(float)); + } + ggml_vec_scale_f32(ne0, (float *) ((char *) dst->data + ir*nb1), scale); + } + } else if (nb10 == sizeof(float)) { for (int64_t ir = ith; ir < nr; ir += nth) { // src0 and dst are same shape => same indices const int64_t i03 = ir/(ne02*ne01); @@ -14169,6 +14190,7 @@ static void ggml_compute_forward_clamp( } break; case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_3: case GGML_TYPE_Q2_2: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: @@ -21319,6 +21341,7 @@ size_t ggml_quantize_chunk( size_t result = 0; switch (type) { + case GGML_TYPE_Q1_3: result = quantize_q1_3(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q2_2: result = quantize_q2_2(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_0: result = quantize_q4_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_1: result = quantize_q4_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/ggml.h b/ggml.h index 4ec555ccb39b2..ab8ca8770ab16 100644 --- a/ggml.h +++ b/ggml.h @@ -378,6 +378,7 @@ extern "C" { GGML_TYPE_IQ1_M = 29, GGML_TYPE_BF16 = 30, GGML_TYPE_Q2_2 = 31, + GGML_TYPE_Q1_3 = 32, GGML_TYPE_COUNT, }; diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 301200869b6f0..8a09efb427a01 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -924,6 +924,7 @@ class GGMLQuantizationType(IntEnum): IQ1_M = 29 BF16 = 30 Q2_2 = 31 + Q1_3 = 32 # TODO: add GGMLFileType from ggml_ftype in ggml.h @@ -966,6 +967,7 @@ class LlamaFileType(IntEnum): MOSTLY_IQ1_M = 31 # except 1d tensors MOSTLY_BF16 = 32 # except 1d tensors MOSTLY_Q2_2 = 33 # except 1d tensors + MOSTLY_Q1_3 = 34 # except 1d tensors GUESSED = 1024 # not specified in the model file @@ -1040,6 +1042,7 @@ def get_type(val: Any) -> GGUFValueType: GGMLQuantizationType.F64: (1, 8), GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32), GGMLQuantizationType.BF16: (1, 2), + GGMLQuantizationType.Q1_3: (64, 12 + 1), } diff --git a/gguf-py/gguf/quants.py b/gguf-py/gguf/quants.py index b22eec1661ce7..a2beb0d53375a 100644 --- a/gguf-py/gguf/quants.py +++ b/gguf-py/gguf/quants.py @@ -121,3 +121,53 @@ def quantize_q8_0(data: np.ndarray): return __quantize_q8_0_lazy(data) else: return __quantize_q8_0_array(data) + + +__q1_3_block_size, __q1_3_type_size = GGML_QUANT_SIZES[GGMLQuantizationType.Q1_3] + + +def __quantize_q1_3_shape_change(s: tuple[int, ...]) -> tuple[int, ...]: + return (*s[:-1], s[-1] // __q1_3_block_size * __q1_3_type_size) + + +def __quantize_q1_3_rows(n: np.ndarray) -> np.ndarray: + shape = n.shape + assert shape[-1] % __q1_3_block_size == 0 + + n_blocks = n.size // __q1_3_block_size + + blocks = n.reshape((n_blocks, __q1_3_block_size)).astype(np.float32, copy=False) + + # assuming the weights are pre-scaled + blocks = (np.sign(blocks).astype(np.int8) + 1).view(np.uint8) + q48, rest = np.hsplit(blocks, (48,)) + q12, q4 = np.hsplit(rest, (12,)) + + pow3 = np.array([1, 3, 9, 27]) + q48 = q48.reshape((n_blocks, 12, 4)) + q48 = np.sum(q48 * pow3.reshape((1, 1, 4)), axis=2, keepdims=True).reshape((n_blocks, 12)) + q4 = np.sum(q4 * pow3.reshape((1, 4)), axis=1, keepdims=True) + q48 = q48 + (q12 * 81) + q = np.concatenate([q48, q4], axis=1); + q = ((q.astype(np.uint16) * 256) // 243).astype(np.uint8) + q = np.where(q != 0, q + 1, 0); + + return q.reshape(__quantize_q1_3_shape_change(shape)) + + +def __quantize_q1_3_array(n: np.ndarray) -> np.ndarray: + return __apply_over_grouped_rows(__quantize_q1_3_rows, arr=n, otype=np.uint8, oshape=__quantize_q1_3_shape_change(n.shape)) + + +__quantize_q1_3_lazy = LazyNumpyTensor._wrap_fn( + __quantize_q1_3_array, + meta_noop=(np.uint8, __quantize_q1_3_shape_change), +) + + +def quantize_q1_3(data: np.ndarray): + if type(data) is LazyNumpyTensor: + return __quantize_q1_3_lazy(data) + else: + return __quantize_q1_3_array(data) + diff --git a/llama.cpp b/llama.cpp index 85182f4bbeea1..7a44be3b0c89c 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3885,7 +3885,8 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_ALL_F32: return "all F32"; case LLAMA_FTYPE_MOSTLY_F16: return "F16"; case LLAMA_FTYPE_MOSTLY_BF16: return "BF16"; - case LLAMA_FTYPE_MOSTLY_Q2_2: return "Q2_2"; + case LLAMA_FTYPE_MOSTLY_Q1_3: return "Q1_3 - 1.625 bpw for BitNet 1.58b"; + case LLAMA_FTYPE_MOSTLY_Q2_2: return "Q2_2 - 2.000 bpw for BitNet 1.58b"; case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0"; case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1"; case LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16: @@ -15462,6 +15463,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s llama_ftype ftype = params->ftype; switch (params->ftype) { + case LLAMA_FTYPE_MOSTLY_Q1_3: default_type = GGML_TYPE_Q1_3; break; case LLAMA_FTYPE_MOSTLY_Q2_2: default_type = GGML_TYPE_Q2_2; break; case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break; case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break; diff --git a/llama.h b/llama.h index 7a2e0e31c1bb4..6b0c8a34a88a8 100644 --- a/llama.h +++ b/llama.h @@ -157,6 +157,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q2_2 = 33, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q1_3 = 34, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file };