diff --git a/ggml-quants.c b/ggml-quants.c
index a157301..1019eb9 100644
--- a/ggml-quants.c
+++ b/ggml-quants.c
@@ -5,7 +5,7 @@
 #include
 #include
 #include
-
+#include
 #ifdef __ARM_NEON
 
 // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
@@ -2583,6 +2583,36 @@ void ggml_axpy_q4_0_q8_0(const int n, const void * restrict vx, const void * res
     }
 #endif
 }
+void ggml_axpy_q4_1_q8_1(const int n, const void * restrict vx, const void * restrict vy, const void * restrict vz, int8_t alpha, ggml_fp16_t scale) {
+    const int qk = QK8_1;
+    const int nb = n / qk;
+    assert(n % qk == 0);
+    assert(nb % 2 == 0);
+
+    const block_q4_1 * restrict x = vx;
+
+    float *res = (float *)vz;
+    const float scale_fp32 = GGML_FP16_TO_FP32(scale);
+    // a Q4_1 weight dequantizes as w = d*q + m, so alpha*scale*w splits into
+    // (q*alpha)*(d*scale) + m*(alpha*scale); y0 carries the second factor
+    const float y0 = scale_fp32 * alpha;
+    for (int i = 0; i < nb; i++) {
+        const float result_scale = GGML_FP16_TO_FP32(x[i].d) * scale_fp32;
+        const float m = GGML_FP16_TO_FP32(x[i].m);
+        const int offset = i * QK4_1;
+
+        for (int j = 0; j < qk/2; ++j) {
+            // each byte packs two 4-bit quants: low nibble -> j, high nibble -> j + qk/2
+            const int v0 = (x[i].qs[j] & 0x0F);
+            const int v1 = (x[i].qs[j] >> 4);
+            res[offset + j]        += ((int)v0*(int)alpha)*result_scale + m*y0;
+            res[offset + j + qk/2] += ((int)v1*(int)alpha)*result_scale + m*y0;
+        }
+    }
+}
 
 void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
diff --git a/ggml-quants.h b/ggml-quants.h
index 07b02b9..f622374 100644
--- a/ggml-quants.h
+++ b/ggml-quants.h
@@ -212,6 +212,7 @@ void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int
 void ggml_axpy_q8_0_q8_0(const int n, const void * restrict vx, const void * restrict vy, const void * restrict vz, int8_t alpha, ggml_fp16_t scale);
 void ggml_axpy_q4_0_q8_0(const int n, const void * restrict vx, const void * restrict vy, const void * restrict vz, int8_t alpha, ggml_fp16_t scale);
+void ggml_axpy_q4_1_q8_1(const int n, const void * restrict vx, const void * restrict vy, const void * restrict vz, int8_t alpha, ggml_fp16_t scale);
 
 // Dot product
 void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
diff --git a/ggml.c b/ggml.c
index 0f84a7e..ea4ae91 100644
--- a/ggml.c
+++ b/ggml.c
@@ -14755,6 +14755,162 @@ static void ggml_compute_forward_mul_mat_axpy_q8_0(
     _freea(vec);
 #endif
 }
+static void ggml_compute_forward_mul_mat_axpy_q4_1(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const enum ggml_type type = src0->type;
+
+    // const bool src1_cont = ggml_is_contiguous(src1);
+
+    // ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
+    enum ggml_type vec_dot_type = type_traits[type].vec_dot_type;
+    // TODO: do not hardcode the activation quantization type
+    vec_dot_type = GGML_TYPE_Q8_0;
+    ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
+
+    const float threshold = sparse_pred_threshold;
+
+    // GGML_ASSERT(ne0 == ne01);
+    // GGML_ASSERT(ne1 == ne11);
+    // GGML_ASSERT(ne2 == ne12);
+    // GGML_ASSERT(ne3 == ne13);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == ggml_type_size(type));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    // broadcast factors
+    // const int64_t r2 = ne12/ne02;
+    // const int64_t r3 = ne13/ne03;
+
+    // nb01 >= nb00 - src0 is not transposed
+    //   compute by src0 rows
+    if (params->type == GGML_TASK_INIT) {
+        ggml_set_zero(dst);
+        if (src1->type != vec_dot_type) {
+            char * wdata = params->wdata;
+            const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
+
+            for (int64_t i13 = 0; i13 < ne13; ++i13) {
+                for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                    for (int64_t i11 = 0; i11 < ne11; ++i11) {
+                        from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
+                        wdata += row_size;
+                    }
+                }
+            }
+        }
+        atomic_store(params->aic, 0);
+
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    ggml_fp16_t* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
+    const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
+
+    struct ggml_tensor *src2 = dst->src[2];
+
+    // parallelize by src0 rows
+    // const int64_t dr = (src2->ne[0] + 8*nth - 1)/(8*nth);
+    const int64_t dr = (src2->ne[0] + nth - 1)/(nth);
+    const int nr = ggml_nrows(src0);
+
+    const int64_t ir10 = dr*ith;
+    // const int64_t ir11 = MIN(ir10 + dr, src2->ne[0]);
+
+    // src1 rows
+    const int64_t nr1 = ne11*ne12*ne13;
+    float *idx = src2->data;
+    int idx_row_size = src2->nb[1];
+    int *gid = (int *)(dst->src[3]->data);
+    // printf("down %d up %d ne00 %d\n", ir10, ir11, ne00);
+
+#if defined(_MSC_VER)
+    float* vec = (float *)_malloca(ne00 * 4 * sizeof(float));
+#else
+    float vec[ne00*4];
+#endif
+    void *vy = vec;
+    char* src0_row = (char *) src0->data;
+    for (int col_idx = 0; col_idx < nr1; col_idx++) {
+        // quantized activations (Q8_0 blocks) for this src1 column
+        const block_q8_0 *restrict neural = (block_q8_0 *)((char *)wdata + col_idx * row_size);
+        idx = (float *)((char *)src2->data + col_idx * idx_row_size);
+        memset(vy, 0, ne00 * sizeof(float));
+        for (int64_t ir1 = ir10; ir1 < ir10 + dr; ir1++) {
+            if (ir1 >= nr)
+                break;
+            if (gid[ir1] == 1)
+                continue;
+            if (idx[ir1] < threshold)
+                continue;
+            int bid  = ir1 / QK8_1;
+            int qsid = ir1 % QK8_1;
+            int b = (int)neural[bid].qs[qsid];
+            if (b == 0)
+                continue;
+            ggml_fp16_t d = neural[bid].d;
+            ggml_axpy_q4_1_q8_1(ne00, src0_row + nb01 * ir1, vy, vy, b, d);
+        }
+
+        float *res = (float *)((char *)(dst->data) + col_idx * nb1);
+        float *tmp = (float *)vy;
+        int i;
+
+        // number of remaining elements after the 8-wide vector loop
+        int remainder = ne00 % 8;
+#if defined(__AVX2__)
+        // vectorized accumulation using AVX instructions
+        for (i = 0; i < ne00 - remainder; i += 8) {
+            __m256 res_vec = _mm256_loadu_ps(res + i);        // load 8 floats from res
+            __m256 tmp_vec = _mm256_loadu_ps(tmp + i);        // load 8 floats from tmp
+            __m256 result  = _mm256_add_ps(res_vec, tmp_vec); // add the two vectors
+            _mm256_storeu_ps(res + i, result);                // store the result back to res
+        }
+
+        // handle the remaining elements
+        for (i = ne00 - remainder; i < ne00; i++) {
+            res[i] += tmp[i];
+        }
+#else
+        for (i = 0; i < ne00; i++) {
+            res[i] += tmp[i];
+        }
+#endif
+    }
+#if defined(_MSC_VER)
+    _freea(vec);
+#endif
+}
 static void ggml_compute_forward_mul_mat_axpy_q4_0(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -15216,6 +15372,9 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
         case GGML_TYPE_Q4_0:
             ggml_compute_forward_mul_mat_axpy_q4_0(params, tensor->src[0], tensor->src[1], tensor);
             break;
+        case GGML_TYPE_Q4_1:
+            ggml_compute_forward_mul_mat_axpy_q4_1(params, tensor->src[0], tensor->src[1], tensor);
+            break;
         case GGML_TYPE_Q8_0:
             ggml_compute_forward_mul_mat_axpy_q8_0(params, tensor->src[0], tensor->src[1], tensor);
            break;
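
For reference, the new Q4_1 path differs from the existing Q4_0 axpy only in the per-block minimum `m`: a Q4_1 weight dequantizes as `w = d*q + m`, so the scaled update `alpha*scale*w` can be split into `(q*alpha)*(d*scale) + m*(alpha*scale)`, which is the factoring applied in the inner loop of `ggml_axpy_q4_1_q8_1`. The sketch below is a minimal standalone check of that identity; it is not part of the patch, and the block size, constants, and `main` harness are illustrative assumptions rather than ggml API.

```c
// Standalone sketch (not part of the patch): checks that the factored form used in
// ggml_axpy_q4_1_q8_1 matches naive dequantize-then-multiply for one Q4_1-style block.
// Plain float stands in for ggml_fp16_t; block size and constants are illustrative only.
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#define BLOCK_SIZE 32  // plays the role of QK4_1

int main(void) {
    const float  d     = 0.0123f;   // per-block scale (x[i].d)
    const float  m     = -0.4567f;  // per-block min   (x[i].m)
    const float  scale = 0.0789f;   // q8 activation scale
    const int8_t alpha = -37;       // quantized activation value

    float naive[BLOCK_SIZE]    = {0};
    float factored[BLOCK_SIZE] = {0};

    const float result_scale = d * scale;           // folds both scales, as in the kernel
    const float y0           = scale * (float) alpha; // multiplies the per-block min

    for (int j = 0; j < BLOCK_SIZE; ++j) {
        int   q = rand() % 16;                       // a 4-bit quant in [0, 15]
        float w = q * d + m;                         // Q4_1 dequantization
        naive[j]    += w * scale * (float) alpha;    // dequantize, then axpy
        factored[j] += (q * alpha) * result_scale + m * y0; // kernel's factored form
    }

    float max_err = 0.0f;
    for (int j = 0; j < BLOCK_SIZE; ++j) {
        float e = fabsf(naive[j] - factored[j]);
        if (e > max_err) max_err = e;
    }
    printf("max abs difference: %g\n", max_err); // expected: ~0, up to float rounding
    return 0;
}
```

Since the two expressions are algebraically identical, any residual difference in the printed value comes from float rounding only.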