More quantize support #176

Open — wants to merge 13 commits into base: main

Changes from all commits
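This PR extends the sparse CUDA code paths in ggml-cuda.cu, previously wired up only for Q4_0, to the remaining classic quant formats: Q4_1, Q5_0, Q5_1, and Q8_0. Each format gains batch dequantize-mul-mat, quantized mat-vec (q8_1 dot), and sparse axpy launchers, plus matching entries in the type-dispatch switches.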
213 changes: 210 additions & 3 deletions ggml-cuda.cu
@@ -5325,11 +5325,43 @@ static void dequantize_mul_mat_batch_q4_0_cuda_sparse(const void * vx, const dfl
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
// printf("ncols %d, nrows %d, src1_ncols %d, dst_ne0 %d\n", ncols, nrows, src1_ncols, dst_ne0);

dequantize_mul_mat_batch_sparse<QK4_0, QR4_0, dequantize_q4_0>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows, src1_ncols, dst_ne0, lst, idx);

}
static void dequantize_mul_mat_batch_q4_1_cuda_sparse(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, int src1_ncols, int dst_ne0, cudaStream_t stream, int *lst, float *idx) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_batch_sparse<QK4_1, QR4_1, dequantize_q4_1>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows, src1_ncols, dst_ne0, lst, idx);
}
static void dequantize_mul_mat_batch_q5_0_cuda_sparse(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, int src1_ncols, int dst_ne0, cudaStream_t stream, int *lst, float *idx) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_batch_sparse<QK5_0, QR5_0, dequantize_q5_0>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows, src1_ncols, dst_ne0, lst, idx);
}
static void dequantize_mul_mat_batch_q5_1_cuda_sparse(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, int src1_ncols, int dst_ne0, cudaStream_t stream, int *lst, float *idx) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_batch_sparse<QK5_1, QR5_1, dequantize_q5_1>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows, src1_ncols, dst_ne0, lst, idx);
}
static void dequantize_mul_mat_batch_q8_0_cuda_sparse(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, int src1_ncols, int dst_ne0, cudaStream_t stream, int *lst, float *idx) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);

dequantize_mul_mat_batch_sparse<QK8_0, QR8_0, dequantize_q8_0>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows, src1_ncols, dst_ne0, lst, idx);

}
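All of these batch launchers share one launch geometry: a 1 × block_num_y grid with GGML_CUDA_MMV_Y rows per block and one warp in x. A minimal standalone sketch of that sizing, assuming ggml's usual WARP_SIZE = 32 and GGML_CUDA_MMV_Y = 1 (the constants here are illustrative, not taken from this PR):

```cuda
#include <cstdio>

// Assumed defaults; ggml-cuda.cu defines the real values.
#define WARP_SIZE 32
#define GGML_CUDA_MMV_Y 1

int main() {
    const int nrows = 4096;  // hypothetical weight-matrix row count
    // Same ceil-division as in the launchers above: one y-block
    // per GGML_CUDA_MMV_Y rows, rounded up.
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    printf("grid=(1,%d,1) block=(%d,%d,1)\n",
           block_num_y, WARP_SIZE, GGML_CUDA_MMV_Y);
    return 0;
}
```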

static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
@@ -5466,6 +5498,38 @@ static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float *
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, lst, idx);
}
static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream, const int *lst, const float *idx) {
GGML_ASSERT(ncols % QK4_1 == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, lst, idx);
}
static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream, const int *lst, const float *idx) {
GGML_ASSERT(ncols % QK5_0 == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, lst, idx);
}
static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream, const int *lst, const float *idx) {
GGML_ASSERT(ncols % QK5_1 == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, lst, idx);
}
static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream, const int *lst, const float *idx) {
GGML_ASSERT(ncols % QK8_0 == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, lst, idx);
}
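For context, a hedged host-side sketch of driving one of the new launchers directly. Every name below is hypothetical; in the real code path, ggml_cuda_op_mul_mat_vec_sparse_q (further down) supplies these arguments, with src1 already quantized to q8_1:

```cuda
// Hypothetical driver; illustrates argument roles only.
static void example_mmv_q8_0(const void * d_w_q8_0,  // quantized weight rows
                             const void * d_x_q8_1,  // q8_1-quantized activations
                             float * d_y,            // output, one float per row
                             int ncols, int nrows,
                             const int * d_lst,      // sparse row lookup
                             const float * d_idx,    // predictor activations
                             cudaStream_t stream) {
    // ncols must be a multiple of QK8_0, or the launcher's assert fires.
    mul_mat_vec_q8_0_q8_1_cuda(d_w_q8_0, d_x_q8_1, d_y,
                               ncols, nrows, stream, d_lst, d_idx);
}
```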


static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
@@ -5596,16 +5660,50 @@ static void dequantize_axpy_vec_q4_0_cuda(const void * vx, const dfloat * y, flo
dequantize_mul_mat_axpy<QK4_0, QR4_0, dequantize_q4_0>
<<<block_nums, block_dims, ncols*sizeof(float), stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_axpy_sparse_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream, int *lst, float *idx) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_axpy_sparse<QK8_0, QR8_0, dequantize_q8_0>
<<<block_nums, block_dims, ncols*sizeof(float), stream>>>(vx, y, dst, ncols, nrows, lst, idx);
}
static void dequantize_axpy_sparse_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream, int *lst, float *idx) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_axpy_sparse<QK4_0, QR4_0, dequantize_q4_0>
<<<block_nums, block_dims, ncols*sizeof(float), stream>>>(vx, y, dst, ncols, nrows, lst, idx);
}
static void dequantize_axpy_sparse_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream, int *lst, float *idx) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_axpy_sparse<QK4_1, QR4_1, dequantize_q4_1>
<<<block_nums, block_dims, ncols*sizeof(float), stream>>>(vx, y, dst, ncols, nrows, lst, idx);
}
static void dequantize_axpy_sparse_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream, int *lst, float *idx) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_axpy_sparse<QK5_0, QR5_0, dequantize_q5_0>
<<<block_nums, block_dims, ncols*sizeof(float), stream>>>(vx, y, dst, ncols, nrows, lst, idx);
}
static void dequantize_axpy_sparse_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream, int *lst, float *idx) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_axpy_sparse<QK5_1, QR5_1, dequantize_q5_1>
<<<block_nums, block_dims, ncols*sizeof(float), stream>>>(vx, y, dst, ncols, nrows, lst, idx);
}
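Unlike the mul_mat_vec launchers, each axpy launcher requests ncols * sizeof(float) bytes of dynamic shared memory per block, so very wide matrices can exceed the device limit. A hedged sketch of a pre-launch guard (not part of this PR, shown only to make that launch parameter concrete):

```cuda
#include <cuda_runtime.h>
#include <cstdio>

// Returns true if ncols floats of dynamic shared memory fit per block.
// Illustrative guard only; the launchers above do not perform it.
static bool axpy_smem_fits(int ncols, int device) {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device);
    const size_t needed = (size_t) ncols * sizeof(float);
    if (needed > prop.sharedMemPerBlock) {
        fprintf(stderr, "ncols=%d needs %zu bytes of smem, limit is %zu\n",
                ncols, needed, prop.sharedMemPerBlock);
        return false;
    }
    return true;
}
```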



static void dequantize_axpy_sparse_batch_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, int src1_rows, int src1_ncols, cudaStream_t stream, int *lst, float *idx) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
@@ -5615,6 +5713,38 @@ static void dequantize_axpy_sparse_batch_q4_0_cuda(const void * vx, const dfloat
dequantize_mul_mat_axpy_sparse_batch<QK4_0, QR4_0, dequantize_q4_0>
<<<block_nums, block_dims, ncols*sizeof(float), stream>>>(vx, y, dst, ncols, nrows, src1_rows, src1_ncols, lst, idx);
}
static void dequantize_axpy_sparse_batch_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, int src1_rows, int src1_ncols, cudaStream_t stream, int *lst, float *idx) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_axpy_sparse_batch<QK4_1, QR4_1, dequantize_q4_1>
<<<block_nums, block_dims, ncols*sizeof(float), stream>>>(vx, y, dst, ncols, nrows, src1_rows, src1_ncols, lst, idx);
}
static void dequantize_axpy_sparse_batch_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, int src1_rows, int src1_ncols, cudaStream_t stream, int *lst, float *idx) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_axpy_sparse_batch<QK5_0, QR5_0, dequantize_q5_0>
<<<block_nums, block_dims, ncols*sizeof(float), stream>>>(vx, y, dst, ncols, nrows, src1_rows, src1_ncols, lst, idx);
}
static void dequantize_axpy_sparse_batch_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, int src1_rows, int src1_ncols, cudaStream_t stream, int *lst, float *idx) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_axpy_sparse_batch<QK5_1, QR5_1, dequantize_q5_1>
<<<block_nums, block_dims, ncols*sizeof(float), stream>>>(vx, y, dst, ncols, nrows, src1_rows, src1_ncols, lst, idx);
}
static void dequantize_axpy_sparse_batch_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, int src1_rows, int src1_ncols, cudaStream_t stream, int *lst, float *idx) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_axpy_sparse_batch<QK8_0, QR8_0, dequantize_q8_0>
<<<block_nums, block_dims, ncols*sizeof(float), stream>>>(vx, y, dst, ncols, nrows, src1_rows, src1_ncols, lst, idx);
}
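The six axpy-batch launchers above differ only in their <QK, QR, dequantize> triple. A hedged refactor sketch that collapses them behind one more template layer (a suggestion, not code from this PR; it assumes ggml's dequantize_kernel_t typedef):

```cuda
// Refactor sketch: one host-side template instead of six near-identical copies.
template <int qk, int qr, dequantize_kernel_t dequantize>
static void dequantize_axpy_sparse_batch_cuda(
        const void * vx, const dfloat * y, float * dst,
        const int ncols, const int nrows, int src1_rows, int src1_ncols,
        cudaStream_t stream, int * lst, float * idx) {
    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(1, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
    dequantize_mul_mat_axpy_sparse_batch<qk, qr, dequantize>
        <<<block_nums, block_dims, ncols*sizeof(float), stream>>>(
            vx, y, dst, ncols, nrows, src1_rows, src1_ncols, lst, idx);
}
// Usage, equivalent to dequantize_axpy_sparse_batch_q5_0_cuda:
//   dequantize_axpy_sparse_batch_cuda<QK5_0, QR5_0, dequantize_q5_0>(...);
```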
static void convert_axpy_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
@@ -6951,6 +7081,18 @@ inline void ggml_cuda_op_mul_mat_batch_sparse(
case GGML_TYPE_Q4_0:
dequantize_mul_mat_batch_q4_0_cuda_sparse(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, src1_ncols, dst->ne[0], stream, row_lookup, sparse_idx);
break;
case GGML_TYPE_Q4_1:
dequantize_mul_mat_batch_q4_1_cuda_sparse(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, src1_ncols, dst->ne[0], stream, row_lookup, sparse_idx);
break;
case GGML_TYPE_Q5_0:
dequantize_mul_mat_batch_q5_0_cuda_sparse(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, src1_ncols, dst->ne[0], stream, row_lookup, sparse_idx);
break;
case GGML_TYPE_Q5_1:
dequantize_mul_mat_batch_q5_1_cuda_sparse(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, src1_ncols, dst->ne[0], stream, row_lookup, sparse_idx);
break;
case GGML_TYPE_Q8_0:
dequantize_mul_mat_batch_q8_0_cuda_sparse(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, src1_ncols, dst->ne[0], stream, row_lookup, sparse_idx);
break;
case GGML_TYPE_F16:
convert_mul_mat_batch_f16_cuda_sparse(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, src1_ncols, dst->ne[0], stream, row_lookup, sparse_idx);
break;
@@ -7199,6 +7341,30 @@ inline void ggml_cuda_op_mul_mat_vec_sparse_q(
row_lookup, sparse_idx
);
break;
case GGML_TYPE_Q4_1:
mul_mat_vec_q4_1_q8_1_cuda(
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream,
row_lookup, sparse_idx
);
break;
case GGML_TYPE_Q5_0:
mul_mat_vec_q5_0_q8_1_cuda(
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream,
row_lookup, sparse_idx
);
break;
case GGML_TYPE_Q5_1:
mul_mat_vec_q5_1_q8_1_cuda(
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream,
row_lookup, sparse_idx
);
break;
case GGML_TYPE_Q8_0:
mul_mat_vec_q8_0_q8_1_cuda(
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream,
row_lookup, sparse_idx
);
break;
default:
GGML_ASSERT(false && "Unsupported type");
break;
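Note that every *_q8_1 launcher dispatched here receives src1 data (src1_ddq_i) that was already quantized to q8_1 on the device earlier in the mul_mat pipeline; sparse row selection arrives through row_lookup and sparse_idx, so the per-type cases differ only in which launcher they call.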
@@ -7563,6 +7729,42 @@ inline void ggml_cuda_op_dequantize_axpy(
dequantize_axpy_sparse_batch_q4_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, ne10, src1_ncols, stream, row_lookup, sparse_idx);
}
break;
case GGML_TYPE_Q4_1:
if (sparse_idx == NULL) {
GGML_ASSERT(false);
} else if (ne11 == 1) {
dequantize_axpy_sparse_vec_q4_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream, row_lookup, sparse_idx);
} else {
dequantize_axpy_sparse_batch_q4_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, ne10, src1_ncols, stream, row_lookup, sparse_idx);
}
break;
case GGML_TYPE_Q5_0:
if (sparse_idx == NULL) {
GGML_ASSERT(false);
} else if (ne11 == 1) {
dequantize_axpy_sparse_vec_q5_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream, row_lookup, sparse_idx);
} else {
dequantize_axpy_sparse_batch_q5_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, ne10, src1_ncols, stream, row_lookup, sparse_idx);
}
break;
case GGML_TYPE_Q5_1:
if (sparse_idx == NULL) {
GGML_ASSERT(false);
} else if (ne11 == 1) {
dequantize_axpy_sparse_vec_q5_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream, row_lookup, sparse_idx);
} else {
dequantize_axpy_sparse_batch_q5_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, ne10, src1_ncols, stream, row_lookup, sparse_idx);
}
break;
case GGML_TYPE_Q8_0:
if (sparse_idx == NULL) {
GGML_ASSERT(false);
} else if (ne11 == 1) {
dequantize_axpy_sparse_vec_q8_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream, row_lookup, sparse_idx);
} else {
dequantize_axpy_sparse_batch_q8_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, ne10, src1_ncols, stream, row_lookup, sparse_idx);
}
break;
case GGML_TYPE_F16:
if (sparse_idx == NULL) {
convert_axpy_vec_f16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
@@ -8723,7 +8925,12 @@ static void ggml_cuda_mul_mat_sparse(const ggml_tensor * src0, const ggml_tensor
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_sparse_dequantized, false);
break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_sparse_q, true);
break;
default:
GGML_ASSERT(false && "unsupported type for sparse matrix multiplication");