From d962a56baa4c0591789d25d3f78817e50d487628 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 11 Jun 2024 15:43:39 +0200 Subject: [PATCH 1/6] CUDA: faster q2_K, q3_K MMQ + int8 tensor cores --- ggml-cuda.cu | 6 +- ggml-cuda/argsort.cu | 1 + ggml-cuda/common.cuh | 1 + ggml-cuda/mmq.cuh | 672 +++++++++++++++++++++++------------------- ggml-cuda/quantize.cu | 58 +++- ggml-cuda/softmax.cu | 1 + ggml-cuda/vecdotq.cuh | 36 ++- 7 files changed, 437 insertions(+), 338 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 64d3b6747fc41..593fa4cdaa514 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -188,13 +188,15 @@ static ggml_cuda_device_info ggml_cuda_init() { info.default_tensor_split[id] = total_vram; total_vram += prop.totalGlobalMem; + info.devices[id].nsm = prop.multiProcessorCount; + info.devices[id].smpb = prop.sharedMemPerBlock; #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + info.devices[id].smpbo = prop.sharedMemPerBlock; info.devices[id].cc = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD; #else + info.devices[id].smpbo = prop.sharedMemPerBlockOptin; info.devices[id].cc = 100*prop.major + 10*prop.minor; #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - info.devices[id].smpb = prop.sharedMemPerBlock; - info.devices[id].nsm = prop.multiProcessorCount; } for (int id = 0; id < info.device_count; ++id) { diff --git a/ggml-cuda/argsort.cu b/ggml-cuda/argsort.cu index 1641440617779..15757ca18e4d7 100644 --- a/ggml-cuda/argsort.cu +++ b/ggml-cuda/argsort.cu @@ -73,6 +73,7 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co const dim3 block_nums(1, nrows, 1); const size_t shared_mem = ncols_pad * sizeof(int); + // FIXME: this limit could be raised by ~2-4x on Ampere or newer GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb); if (order == GGML_SORT_ORDER_ASC) { diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 7f4764d60e854..3f51548d01c26 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -661,6 +661,7 @@ struct ggml_cuda_device_info { int cc; // compute capability int nsm; // number of streaming multiprocessors size_t smpb; // max. shared memory per block + size_t smpbo; // max. shared memory per block (with opt-in) bool vmm; // virtual memory support size_t vmm_granularity; // granularity of virtual memory size_t total_vram; diff --git a/ggml-cuda/mmq.cuh b/ggml-cuda/mmq.cuh index 01e2086b41646..594f0742d144f 100644 --- a/ggml-cuda/mmq.cuh +++ b/ggml-cuda/mmq.cuh @@ -10,10 +10,10 @@ #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1) typedef void (*load_tiles_mmq_t)( - const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride); typedef void (*vec_dot_mmq_t)( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0); typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1); @@ -25,9 +25,8 @@ static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected b static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size"); struct tile_x_sizes { - int ql; + int qs; int dm; - int qh; int sc; }; @@ -67,16 +66,16 @@ static constexpr __device__ int get_mmq_y_device(int /*mmq_x*/) { #endif // __CUDA_ARCH__ >= CC_VOLTA #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0, 0} -#define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0, 0} -#define TILE_X_SIZES_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0, 0} -#define TILE_X_SIZES_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0, 0} -#define TILE_X_SIZES_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0, 0} -#define TILE_X_SIZES_Q2_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI2_K + mmq_y/QI2_K, 0, mmq_y*WARP_SIZE/4 + mmq_y/4} -#define TILE_X_SIZES_Q3_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/2 + mmq_y/2, mmq_y*WARP_SIZE/4 + mmq_y/4} -#define TILE_X_SIZES_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, 0, mmq_y*WARP_SIZE/8 + mmq_y/8} -#define TILE_X_SIZES_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, 0, mmq_y*WARP_SIZE/8 + mmq_y/8} -#define TILE_X_SIZES_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, 0, mmq_y*WARP_SIZE/8 + mmq_y/8} +#define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0} +#define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0} +#define TILE_X_SIZES_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0} +#define TILE_X_SIZES_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0} +#define TILE_X_SIZES_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0} +#define TILE_X_SIZES_Q2_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0} +#define TILE_X_SIZES_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/4 + mmq_y/4} +#define TILE_X_SIZES_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8} +#define TILE_X_SIZES_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8} +#define TILE_X_SIZES_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8} #define GET_TILE_X_SIZES_BODY \ return type == GGML_TYPE_Q4_0 ? TILE_X_SIZES_Q4_0 : \ @@ -89,7 +88,7 @@ static constexpr __device__ int get_mmq_y_device(int /*mmq_x*/) { type == GGML_TYPE_Q4_K ? TILE_X_SIZES_Q4_K : \ type == GGML_TYPE_Q5_K ? TILE_X_SIZES_Q5_K : \ type == GGML_TYPE_Q6_K ? TILE_X_SIZES_Q6_K : \ - tile_x_sizes{0, 0, 0, 0} + tile_x_sizes{0, 0, 0} static tile_x_sizes get_tile_x_sizes_host(const ggml_type type, const int mmq_y) { GET_TILE_X_SIZES_BODY; @@ -103,9 +102,9 @@ static constexpr __device__ tile_x_sizes get_tile_x_sizes_device(ggml_type type) // ------------------------------------------------------------ template static __device__ __forceinline__ void load_tiles_q4_0( - const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + GGML_UNUSED(x_sc); const int kbx = threadIdx.x / QI4_0; const int kqsx = threadIdx.x % QI4_0; @@ -122,7 +121,7 @@ template static __device__ __forceinlin const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx; - x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx); + x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx); } const int blocks_per_tile_x_row = WARP_SIZE / QI4_0; @@ -144,10 +143,9 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + GGML_UNUSED(x_sc); const float * x_df = (const float *) x_dm; const int * y_qs = (const int *) y + 4; @@ -172,7 +170,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( } sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl - (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0], + (&x_qs[i*(WARP_SIZE + 1) + k0], u, x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); } } @@ -180,10 +178,9 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + GGML_UNUSED(x_sc); typedef mma_int_A_I16K8 mma_A; typedef mma_int_B_J8K8 mma_B; @@ -205,7 +202,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma( const int k = k0 + mma_A::get_k(l) % QI4_0; const int shift = 4*(mma_A::get_k(l) / QI4_0); - A.x[l] = __vsubss4((x_ql[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F, 0x08080808); + A.x[l] = __vsubss4((x_qs[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F, 0x08080808); } #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { @@ -243,9 +240,9 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma( } template static __device__ __forceinline__ void load_tiles_q4_1( - const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + GGML_UNUSED(x_sc); const int kbx = threadIdx.x / QI4_1; const int kqsx = threadIdx.x % QI4_1; @@ -260,7 +257,7 @@ template static __device__ __forceinlin const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx; - x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx); + x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx); } const int blocks_per_tile_x_row = WARP_SIZE / QI4_1; @@ -282,10 +279,9 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + GGML_UNUSED(x_sc); const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; @@ -309,7 +305,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( } sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl - (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1], + (&x_qs[i*(WARP_SIZE + 1) + k0], u, x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); } } @@ -317,10 +313,9 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + GGML_UNUSED(x_sc); typedef mma_int_A_I16K8 mma_A; typedef mma_int_B_J8K8 mma_B; @@ -341,7 +336,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma( const int k = k0 + mma_A::get_k(l) % QI4_0; const int shift = 4*(mma_A::get_k(l) / QI4_0); - A.x[l] = (x_ql[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F; + A.x[l] = (x_qs[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F; } #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { @@ -380,9 +375,9 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma( } template static __device__ __forceinline__ void load_tiles_q5_0( - const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + GGML_UNUSED(x_sc); const int kbx = threadIdx.x / QI5_0; const int kqsx = threadIdx.x % QI5_0; @@ -407,7 +402,7 @@ template static __device__ __forceinlin qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 qs0 = __vsubss4(qs0, 0x10101010); // subtract 16 - x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0; + x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0; int qs1 = (ql >> 4) & 0x0F0F0F0F; qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 @@ -416,7 +411,7 @@ template static __device__ __forceinlin qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 - x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1; + x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1; } const int blocks_per_tile_x_row = WARP_SIZE / QI5_0; @@ -439,10 +434,9 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + GGML_UNUSED(x_sc); const float * x_dmf = (const float *) x_dm; const int * y_qs = (const int *) y + 4; @@ -468,17 +462,16 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a( } sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl - (&x_ql[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dmf[index_bx], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); + (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dmf[index_bx], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); } } } template static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + GGML_UNUSED(x_sc); typedef mma_int_A_I16K8 mma_A; typedef mma_int_B_J8K8 mma_B; @@ -499,7 +492,7 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma( const int i = i0 + mma_A::get_i(l); const int k = 2*(k0 + mma_A::get_k(l) % QI5_0) + mma_A::get_k(l) / QI5_0; - A.x[l] = x_ql[i*(2*WARP_SIZE + 1) + k]; + A.x[l] = x_qs[i*(2*WARP_SIZE + 1) + k]; } #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { @@ -537,9 +530,9 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma( } template static __device__ __forceinline__ void load_tiles_q5_1( - const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + GGML_UNUSED(x_sc); const int kbx = threadIdx.x / QI5_1; const int kqsx = threadIdx.x % QI5_1; @@ -563,7 +556,7 @@ template static __device__ __forceinlin qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 - x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0; + x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0; int qs1 = (ql >> 4) & 0x0F0F0F0F; qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 @@ -571,7 +564,7 @@ template static __device__ __forceinlin qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 - x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1; + x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1; } const int blocks_per_tile_x_row = WARP_SIZE / QI5_1; @@ -593,10 +586,9 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + GGML_UNUSED(x_sc); const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; @@ -621,17 +613,16 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a( } sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl - (&x_ql[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dm[index_bx], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); + (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dm[index_bx], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); } } } template static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + GGML_UNUSED(x_sc); typedef mma_int_A_I16K8 mma_A; typedef mma_int_B_J8K8 mma_B; @@ -651,7 +642,7 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma( const int i = i0 + mma_A::get_i(l); const int k = 2*(k0 + mma_A::get_k(l) % QI5_1) + mma_A::get_k(l) / QI5_1; - A.x[l] = x_ql[i*(2*WARP_SIZE + 1) + k]; + A.x[l] = x_qs[i*(2*WARP_SIZE + 1) + k]; } #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { @@ -690,10 +681,9 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma( } template static __device__ __forceinline__ void load_tiles_q8_0( - const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { - - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + GGML_UNUSED(x_sc); const int kbx = threadIdx.x / QI8_0; const int kqsx = threadIdx.x % QI8_0; @@ -709,7 +699,7 @@ template static __device__ __forceinlin const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx; - x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx); + x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx); } const int blocks_per_tile_x_row = WARP_SIZE / QI8_0; @@ -731,10 +721,9 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + GGML_UNUSED(x_sc); const float * x_dmf = (const float *) x_dm; const int * y_qs = (const int *) y + 4; @@ -749,7 +738,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( const int i = i0 + threadIdx.x; sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl - (&x_ql[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_dmf[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0], + (&x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_dmf[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + k0/QI8_1]); } } @@ -757,10 +746,9 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + GGML_UNUSED(x_sc); typedef mma_int_A_I16K8 mma_A; typedef mma_int_B_J8K8 mma_B; @@ -781,7 +769,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( const int i = i0 + mma_A::get_i(l); const int k = k0 + mma_A::get_k(l); - A.x[l] = x_ql[i*(WARP_SIZE + 1) + k]; + A.x[l] = x_qs[i*(WARP_SIZE + 1) + k]; } #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { @@ -819,9 +807,8 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( } template static __device__ __forceinline__ void load_tiles_q2_K( - const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { - GGML_UNUSED(x_qh); const int kbx = threadIdx.x / QI2_K; const int kqsx = threadIdx.x % QI2_K; @@ -836,48 +823,33 @@ template static __device__ __forceinlin const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbx; - x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx); - } + const int x_ql_0 = get_int_from_uint8(bxi->qs, kqsx); - const int blocks_per_tile_x_row = WARP_SIZE / QI2_K; - const int kbxd = threadIdx.x % blocks_per_tile_x_row; + x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = 0; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) { - int i = (i0 + threadIdx.y * QI2_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y; + for (int l = 0; l < QR3_K; ++l) { + const int k = kbx*QI2_K + (kqsx/8)*8 + l*2 + (kqsx % 8)/4; - if (need_check) { - i = min(i, i_max); - } - - const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbxd; + int x_qs_k = ((x_ql_0 >> (2*l)) & 0x03030303) << (2*(kqsx % 4)); + x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE); + x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 2, WARP_SIZE); - x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm; - } - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { - int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); - - if (need_check) { - i = min(i, i_max); + x_qs[i*(WARP_SIZE + 1) + k] = x_qs_k; } - const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/4)) / (QI2_K/4); - - x_sc[i * (WARP_SIZE/4) + i / 4 + threadIdx.x % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, threadIdx.x % (QI2_K/4)); + const int sc_m = bxi->scales[kqsx]; + x_dm[i*(WARP_SIZE + 1) + threadIdx.x] = bxi->dm * make_half2(sc_m & 0x0F, sc_m >> 4); } } template -static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, +static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( + const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - GGML_UNUSED(x_qh); - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; + const half2 * y_ds = (const half2 *) y; #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { @@ -887,30 +859,93 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat( for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - const int kbx = k0 / QI2_K; - const int ky = (k0 % QI2_K) * QR2_K; + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq( + &x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR2_K*k0) % WARP_SIZE], + &x_dm[i*(WARP_SIZE + 1) + k0], y_ds[j*MMQ_TILE_Y_K + ((QR2_K*k0) % WARP_SIZE)/QI8_1]); + } + } +} - int v[QR2_K*VDR_Q2_K_Q8_1_MMQ]; +template +static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( + const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, + const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - const int kqsx = i*(WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2); - const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2)); + typedef mma_int_A_I16K4 mma_A; + typedef mma_int_B_J8K4 mma_B; + typedef mma_int_C_I16J8 mma_C; + + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + + const int i0 = threadIdx.y*mma_A::I; + static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); + + mma_A A[2]; + float dA[mma_C::ne/2][2]; + float mA[mma_C::ne/2][2]; #pragma unroll - for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) { - v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303; - } + for (int l = 0; l < mma_A::ne; ++l) { + const int i = i0 + mma_A::get_i(l); + const int shift = 2*mma_A::get_k(l); - const uint8_t * scales = ((const uint8_t *) &x_sc[i*(WARP_SIZE/4) + i/4 + kbx*4]) + ky/4; + A[0].x[l] = (x_qs[i*(WARP_SIZE + 1) + k0 + 0] >> shift) & 0x03030303; + A[1].x[l] = (x_qs[i*(WARP_SIZE + 1) + k0 + 1] >> shift) & 0x03030303; + } - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq( - v, &y_qs[j*MMQ_TILE_Y_K + (QR2_K*k0) % WARP_SIZE], scales, - x_dm[i*(WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[j*MMQ_TILE_Y_K + ((QR2_K*k0) % WARP_SIZE)/QI8_1]); +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + mma_C::get_i(2*l); + +#pragma unroll + for (int kk = 0; kk < 2; ++kk) { + const float2 dm = __half22float2(x_dm[i*(WARP_SIZE + 1) + k0 + kk]); + + dA[l][kk] = dm.x; + mA[l][kk] = dm.y; + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { + mma_C C[2]; + mma_B B[2]; + float dB[mma_C::ne/2]; + float sB[mma_C::ne/2][2]; + +#pragma unroll + for (int l = 0; l < mma_B::ne; ++l) { + const int j = j0 + mma_B::get_j(l); + const int k = (4*k0 + mma_B::get_k(l)) % WARP_SIZE; + + B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0]; + B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K]; + } +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int j = j0 + mma_C::get_j(l); + + const half2 tmp = y_ds[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)]; + dB[l] = __low2float(tmp); + + const int8_t * sBi = (const int8_t *) &tmp.y; + sB[l][0] = (127.0f/8.0f)*sBi[0]; + sB[l][1] = (127.0f/8.0f)*sBi[1]; + } + + C[0].mma_K4(A[0], B[0]); + C[1].mma_K4(A[1], B[1]); + +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_B::J)*mma_C::ne + l] += (C[0].x[l]*dA[l/2][0] + C[1].x[l]*dA[l/2][1] + mA[l/2][0]*sB[l%2][0] + mA[l/2][1]*sB[l%2][1])*dB[l%2]; } } } template static __device__ __forceinline__ void load_tiles_q3_K( - const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { const int kbx = threadIdx.x / QI3_K; @@ -926,7 +961,21 @@ template static __device__ __forceinlin const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbx; - x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx); + const int x_ql_0 = get_int_from_uint8(bxi->qs, kqsx); + const int x_qh_0 = get_int_from_uint8(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2))); + +#pragma unroll + for (int l = 0; l < QR3_K; ++l) { + const int k = kbx*(QR3_K*QI3_K) + (kqsx/8)*32 + l*8 + kqsx % 8; + + const int x_ql_k = (x_ql_0 >> (2*l)) & 0x03030303; + const int x_qh_k = ((x_qh_0 >> l) << 2) & 0x04040404; + + int x_qs_k = (x_ql_k | x_qh_k) << (4*(k%2)); + x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE); + + x_qs[i*(2*WARP_SIZE + 1) + k/2] = x_qs_k; + } } const int blocks_per_tile_x_row = WARP_SIZE / QI3_K; @@ -946,20 +995,6 @@ template static __device__ __forceinlin x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d; } -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) { - int i = i0 + threadIdx.y * 2 + threadIdx.x / (WARP_SIZE/2); - - if (need_check) { - i = min(i, i_max); - } - - const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/2)) / (QI3_K/2); - - // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted - x_qh[i * (WARP_SIZE/2) + i / 2 + threadIdx.x % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, threadIdx.x % (QI3_K/2)); - } - #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); @@ -987,13 +1022,13 @@ template static __device__ __forceinlin } template -static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, +static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a( + const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - const float * x_dmf = (const float *) x_dm; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; + const float * x_df = (const float *) x_dm; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { @@ -1008,31 +1043,97 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat( const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4; - int v[QR3_K*VDR_Q3_K_Q8_1_MMQ]; + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq( + &x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (k0*QR3_K) % WARP_SIZE], scales, + x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[j*MMQ_TILE_Y_K + ((k0*QR3_K) % WARP_SIZE)/QI8_1]); + } + } +} + +template +static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma( + const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, + const int * __restrict__ y, float * __restrict__ sum, const int & k0) { + + typedef mma_int_A_I16K4 mma_A; + typedef mma_int_B_J8K4 mma_B; + typedef mma_int_C_I16J8 mma_C; + + const float * x_df = (const float *) x_dm; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = threadIdx.y*mma_A::I; + static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); + + mma_A A[2]; + int scA[mma_C::ne/2][2]; + float dA[mma_C::ne/2]; #pragma unroll - for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) { - const int kqsx = i*(WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2); - const int shift = 2 * ((ky % 32) / 8); - const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303; + for (int l = 0; l < mma_A::ne; ++l) { + const int i = i0 + mma_A::get_i(l); + const int k = QR3_K*k0 + mma_A::get_k(l); - const int vh = x_qh[i*(WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8); - const int vlh = (vh << 2) & 0x04040404; + A[0].x[l] = (x_qs[i*(2*WARP_SIZE + 1) + k/2 + 0] >> (4*(k%2))) & 0x0F0F0F0F; + A[1].x[l] = (x_qs[i*(2*WARP_SIZE + 1) + k/2 + mma_A::K/2] >> (4*(k%2))) & 0x0F0F0F0F; + A[0].x[l] = __vsubss4(A[0].x[l], 0x04040404); + A[1].x[l] = __vsubss4(A[1].x[l], 0x04040404); + } - v[l] = __vsubss4(vll, vlh); - } +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + mma_C::get_i(2*l); - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq( - v, &y_qs[j*MMQ_TILE_Y_K + (k0*QR3_K) % WARP_SIZE], scales, - x_dmf[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[j*MMQ_TILE_Y_K + ((k0*QR3_K) % WARP_SIZE)/QI8_1]); + const int kbx = k0 / QI3_K; + const int ky = (k0 % QI3_K) * QR3_K; + const int8_t * sc = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4; + + scA[l][0] = sc[0]; + scA[l][1] = sc[1]; + } + +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + mma_C::get_i(2*l); + + dA[l] = x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + k0/QI3_K]; + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { + mma_C C[2]; + mma_B B[2]; + float dB[mma_C::ne/2]; + +#pragma unroll + for (int l = 0; l < mma_B::ne; ++l) { + const int j = j0 + mma_B::get_j(l); + const int k = (4*k0 + mma_B::get_k(l)) % WARP_SIZE; + + B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0]; + B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K]; + } +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int j = j0 + mma_C::get_j(l); + + dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)]; + } + + C[0].mma_K4(A[0], B[0]); + C[1].mma_K4(A[1], B[1]); + +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_B::J)*mma_C::ne + l] += (C[0].x[l]*scA[l/2][0] + C[1].x[l]*scA[l/2][1])*dA[l/2]*dB[l%2]; } } } template static __device__ __forceinline__ void load_tiles_q4_K( - const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { - GGML_UNUSED(x_qh); const int kbx = 0; // threadIdx.x / QI4_K const int kqsx = threadIdx.x; // threadIdx.x % QI4_K @@ -1047,7 +1148,7 @@ template static __device__ __forceinlin const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx; - x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx); + x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx); } const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256 @@ -1090,11 +1191,9 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - GGML_UNUSED(x_qh); - const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; @@ -1109,7 +1208,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2*((k0 % 16) / 8); sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq( - &x_ql[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR4_K*k0) % WARP_SIZE], sc, sc+8, + &x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR4_K*k0) % WARP_SIZE], sc, sc+8, x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + ((QR4_K*k0) % WARP_SIZE)/QI8_1]); } } @@ -1117,11 +1216,9 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - typedef mma_int_A_I16K8 mma_A; typedef mma_int_B_J8K8 mma_B; typedef mma_int_C_I16J8 mma_C; @@ -1143,7 +1240,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma( const int i = i0 + mma_A::get_i(l); const int k = k0 + mma_A::get_k(l); - A[kvdr/4].x[l] = (x_ql[i*(WARP_SIZE + 1) + k] >> kvdr) & 0x0F0F0F0F; + A[kvdr/4].x[l] = (x_qs[i*(WARP_SIZE + 1) + k] >> kvdr) & 0x0F0F0F0F; } #pragma unroll @@ -1207,9 +1304,8 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma( } template static __device__ __forceinline__ void load_tiles_q5_K( - const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { - GGML_UNUSED(x_qh); const int kbx = 0; // threadIdx.x / QI5_K const int kqsx = threadIdx.x; // threadIdx.x % QI5_K @@ -1236,8 +1332,8 @@ template static __device__ __forceinlin const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0; const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + (QI5_K/4); - x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0; - x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1; + x_qs[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0; + x_qs[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1; } const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256 @@ -1280,11 +1376,9 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - GGML_UNUSED(x_qh); - const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; @@ -1299,7 +1393,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8); sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq( - &x_ql[i*(QR5_K*WARP_SIZE + 1) + QR5_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR5_K*k0) % WARP_SIZE], sc, sc+8, + &x_qs[i*(QR5_K*WARP_SIZE + 1) + QR5_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR5_K*k0) % WARP_SIZE], sc, sc+8, x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + ((QR5_K*k0) % WARP_SIZE)/QI8_1]); } } @@ -1307,11 +1401,9 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - typedef mma_int_A_I16K8 mma_A; typedef mma_int_B_J8K8 mma_B; typedef mma_int_C_I16J8 mma_C; @@ -1333,7 +1425,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma( const int i = i0 + mma_A::get_i(l); const int k = QR5_K*k0 + QR5_K*kvdr + mma_A::get_k(l); - A[kvdr/4].x[l] = x_ql[i*(QR5_K*WARP_SIZE + 1) + k]; + A[kvdr/4].x[l] = x_qs[i*(QR5_K*WARP_SIZE + 1) + k]; } #pragma unroll @@ -1397,9 +1489,8 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma( } template static __device__ __forceinline__ void load_tiles_q6_K( - const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { - GGML_UNUSED(x_qh); const int kbx = 0; // threadIdx.x / QI6_K const int kqsx = threadIdx.x; // threadIdx.x % QI6_K @@ -1426,8 +1517,8 @@ template static __device__ __forceinlin const int kq0 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + 0; const int kq1 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + (QI6_K/2); - x_ql[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); - x_ql[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); + x_qs[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); + x_qs[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); } const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256 @@ -1463,11 +1554,9 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - GGML_UNUSED(x_qh); - const float * x_dmf = (const float *) x_dm; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; @@ -1483,7 +1572,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]); sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq( - &x_ql[i*(QR6_K*WARP_SIZE + 1) + QR6_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR6_K*k0) % WARP_SIZE], sc, + &x_qs[i*(QR6_K*WARP_SIZE + 1) + QR6_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR6_K*k0) % WARP_SIZE], sc, x_dmf[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + ((QR6_K*k0) % WARP_SIZE)/QI8_1]); } } @@ -1491,11 +1580,9 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - typedef mma_int_A_I16K4 mma_A; typedef mma_int_B_J8K4 mma_B; typedef mma_int_C_I16J8 mma_C; @@ -1517,8 +1604,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( const int i = i0 + mma_A::get_i(l); const int k = QR6_K*k0 + QR6_K*kvdr + mma_A::get_k(l); - A[kvdr/2 + 0].x[l] = x_ql[i*(QR6_K*WARP_SIZE + 1) + k + 0]; - A[kvdr/2 + 1].x[l] = x_ql[i*(QR6_K*WARP_SIZE + 1) + k + mma_A::K]; + A[kvdr/2 + 0].x[l] = x_qs[i*(QR6_K*WARP_SIZE + 1) + k + 0]; + A[kvdr/2 + 1].x[l] = x_qs[i*(QR6_K*WARP_SIZE + 1) + k + mma_A::K]; } #pragma unroll @@ -1638,142 +1725,104 @@ struct mmq_type_traits; template struct mmq_type_traits { - static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0; -#ifdef INT8_MMA_AVAILABLE - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_mma; - static constexpr mmq_write_back_t write_back = mmq_write_back_mma; -#else - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_dp4a; - static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; -#endif // INT8_MMA_AVAILABLE + static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q4_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a; }; template struct mmq_type_traits { - static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1; -#ifdef INT8_MMA_AVAILABLE - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_mma; - static constexpr mmq_write_back_t write_back = mmq_write_back_mma; -#else - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_dp4a; - static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; -#endif // INT8_MMA_AVAILABLE + static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q4_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a; }; template struct mmq_type_traits { - static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0; -#ifdef INT8_MMA_AVAILABLE - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_mma; - static constexpr mmq_write_back_t write_back = mmq_write_back_mma; -#else - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_dp4a; - static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; -#endif // INT8_MMA_AVAILABLE + static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a; }; template struct mmq_type_traits { - static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1; -#ifdef INT8_MMA_AVAILABLE - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_mma; - static constexpr mmq_write_back_t write_back = mmq_write_back_mma; -#else - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_dp4a; - static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; -#endif // INT8_MMA_AVAILABLE + static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_1_q8_1_dp4a; }; template struct mmq_type_traits { - static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0; -#ifdef INT8_MMA_AVAILABLE - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_mma; - static constexpr mmq_write_back_t write_back = mmq_write_back_mma; -#else - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_dp4a; - static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; -#endif // INT8_MMA_AVAILABLE + static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; template struct mmq_type_traits { - static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K; - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q2_K_q8_1_mul_mat; - static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; + static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a; }; template struct mmq_type_traits { - static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K; - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q3_K_q8_1_mul_mat; - static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; + static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q3_K_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a; }; template struct mmq_type_traits { - static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K; -#ifdef INT8_MMA_AVAILABLE - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mma; - static constexpr mmq_write_back_t write_back = mmq_write_back_mma; -#else - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_dp4a; - static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; -#endif // INT8_MMA_AVAILABLE + static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q4_K_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a; }; template struct mmq_type_traits { - static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K; -#ifdef INT8_MMA_AVAILABLE - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mma; - static constexpr mmq_write_back_t write_back = mmq_write_back_mma; -#else - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_dp4a; - static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; -#endif // INT8_MMA_AVAILABLE + static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_K_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a; }; template struct mmq_type_traits { - static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K; -#ifdef INT8_MMA_AVAILABLE - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mma; - static constexpr mmq_write_back_t write_back = mmq_write_back_mma; -#else - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_dp4a; - static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; -#endif // INT8_MMA_AVAILABLE + static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a; }; static int mmq_need_sum(const ggml_type type_x) { switch (type_x) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - return true; + return 1; case GGML_TYPE_Q5_0: - return false; + return 0; case GGML_TYPE_Q5_1: - return true; + return 1; case GGML_TYPE_Q8_0: + return 0; case GGML_TYPE_Q2_K: + return 2; case GGML_TYPE_Q3_K: - return false; + return 0; case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: - return true; + return 1; case GGML_TYPE_Q6_K: - return false; + return 0; default: GGML_ASSERT(false); break; @@ -1790,7 +1839,7 @@ template #if __CUDA_ARCH__ >= CC_VOLTA __launch_bounds__(WARP_SIZE*nwarps, 1) #else - __launch_bounds__(WARP_SIZE*nwarps, type == GGML_TYPE_Q2_K ? 1 : 2) + __launch_bounds__(WARP_SIZE*nwarps, 2) #endif // __CUDA_ARCH__ >= CC_VOLTA #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) static __global__ void mul_mat_q( @@ -1809,16 +1858,21 @@ static __global__ void mul_mat_q( constexpr int mmq_y = get_mmq_y_device(mmq_x); constexpr int vdr = mmq_type_traits::vdr; constexpr load_tiles_mmq_t load_tiles = mmq_type_traits::load_tiles; - constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot; - constexpr mmq_write_back_t write_back = mmq_type_traits::write_back; + +#ifdef INT8_MMA_AVAILABLE + constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_mma; + constexpr mmq_write_back_t write_back = mmq_write_back_mma; +#else + constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_dp4a; + constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; +#endif // INT8_MMA_AVAILABLE constexpr tile_x_sizes txs = get_tile_x_sizes_device(type); extern __shared__ char data_mul_mat_q[]; - int * tile_x_ql = (int *) data_mul_mat_q; - half2 * tile_x_dm = (half2 *) (tile_x_ql + txs.ql); - int * tile_x_qh = (int *) (tile_x_dm + txs.dm); - int * tile_x_sc = (int *) (tile_x_qh + txs.qh); + int * tile_x_qs = (int *) data_mul_mat_q; + half2 * tile_x_dm = (half2 *) (tile_x_qs + txs.qs); + int * tile_x_sc = (int *) (tile_x_dm + txs.dm); int * tile_y = (int *) (tile_x_sc + txs.sc); // [mmq_x * (WARP_SIZE + WARP_SIZE/QI8_1)] const int blocks_per_row_x = ne00 / qk; @@ -1834,7 +1888,7 @@ static __global__ void mul_mat_q( for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) { - load_tiles(x, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, stride01*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride01); + load_tiles(x, tile_x_qs, tile_x_dm, tile_x_sc, stride01*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride01); #pragma unroll for (int kr = 0; kr < qr; ++kr) { @@ -1850,7 +1904,7 @@ static __global__ void mul_mat_q( // #pragma unroll // unrolling this loop causes too much register pressure for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) { - vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y, sum, k0); + vec_dot(tile_x_qs, tile_x_dm, tile_x_sc, tile_y, sum, k0); } __syncthreads(); @@ -1867,6 +1921,19 @@ struct mmq_args { int64_t ne0; }; +constexpr int mmq_get_nwarps(int mmq_x) { + return mmq_x >= 32 ? 8 : 4; +} + +static int mmq_get_shmem(const ggml_type type, const int mmq_x, const int mmq_y) { + const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y); + const int nwarps = mmq_get_nwarps(mmq_x); + + const int shmem_x = txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); + const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2); + return shmem_x + GGML_PAD(shmem_y, nwarps*WARP_SIZE*sizeof(int)); +} + template static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) { const int id = ggml_cuda_get_device(); @@ -1878,10 +1945,7 @@ static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) { const dim3 block_nums(block_num_x, block_num_y, 1); const dim3 block_dims(WARP_SIZE, nwarps, 1); - const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y); - const int shmem_x = txs.ql*sizeof(int) + txs.dm*sizeof(half2) + txs.qh*sizeof(int) + txs.sc*sizeof(int); - const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2); - const int shmem = shmem_x + GGML_PAD(shmem_y, nwarps*WARP_SIZE*sizeof(int)); + const int shmem = mmq_get_shmem(type, mmq_x, mmq_y); #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; @@ -1905,9 +1969,10 @@ static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) { template void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) { - const int id = ggml_cuda_get_device(); - const int nsm = ggml_cuda_info().devices[id].nsm; - const int cc = ggml_cuda_info().devices[id].cc; + const int id = ggml_cuda_get_device(); + const int nsm = ggml_cuda_info().devices[id].nsm; + const int cc = ggml_cuda_info().devices[id].cc; + const int smpbo = ggml_cuda_info().devices[id].smpbo; const int mmq_x_max = get_mmq_x_max_host(cc); const int mmq_y = get_mmq_y_host(cc, mmq_x_max); @@ -1920,7 +1985,7 @@ void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) { const int block_num_x = (args.ne11 + mmq_x - 1) / mmq_x; const int nwaves = (block_num_x*block_num_y + nsm - 1) / nsm; - if (nwaves < nwaves_best) { + if (nwaves < nwaves_best && mmq_get_shmem(type, mmq_x, mmq_y) <= smpbo) { mmq_x_best = mmq_x; nwaves_best = nwaves; } @@ -1928,54 +1993,55 @@ void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) { switch (mmq_x_best) { case 8: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(args, stream); break; case 16: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(args, stream); break; case 24: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(args, stream); break; case 32: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(args, stream); break; case 40: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(args, stream); break; case 48: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(args, stream); break; case 56: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(args, stream); break; case 64: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(args, stream); break; case 72: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(args, stream); break; case 80: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(args, stream); break; case 88: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(args, stream); break; case 96: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(args, stream); break; case 104: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(args, stream); break; case 112: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(args, stream); break; case 120: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(args, stream); break; case 128: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(args, stream); break; default: + fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best); GGML_ASSERT(false); break; } diff --git a/ggml-cuda/quantize.cu b/ggml-cuda/quantize.cu index b4678682238d3..8d61d8bd625f5 100644 --- a/ggml-cuda/quantize.cu +++ b/ggml-cuda/quantize.cu @@ -1,4 +1,5 @@ #include "quantize.cuh" +#include #include static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded) { @@ -37,7 +38,7 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest reinterpret_cast(y[ib].ds.y) = sum; } -template +template static __global__ void quantize_mmq_q8_1( const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) { @@ -60,24 +61,48 @@ static __global__ void quantize_mmq_q8_1( amax = warp_reduce_max(amax); - float sum; - if (need_sum) { - sum = warp_reduce_sum(xi); - } - const float d = amax / 127; const int8_t q = amax == 0.0f ? 0 : roundf(xi / d); y[ib].qs[iqs] = q; - if (iqs % QK8_1 != 0) { - return; - } + static_assert(need_sum >= 0 && need_sum <= 2, "Invalid need_sum value."); + if (need_sum == 0) { + if (iqs % QK8_1 != 0) { + return; + } + + ((float *) y[ib].ds)[iqs/QK8_1] = d; + } else if (need_sum == 1) { + const float sum = warp_reduce_sum(xi); + + if (iqs % QK8_1 != 0) { + return; + } - if (need_sum) { y[ib].ds[iqs/QK8_1] = make_half2(d, sum); } else { - ((float *) y[ib].ds)[iqs/QK8_1] = d; + float sum = xi; + + // Calculate sum per 16 values: +#pragma unroll + for (int mask = 8; mask > 0; mask >>= 1) { + sum += __shfl_xor_sync(0xffffffff, sum, mask, 32); + } + + if (iqs % (QK8_1/2) != 0) { + return; + } + + int8_t * si = (int8_t *) &y[ib].ds[iqs/QK8_1].y; + const int tmp = roundf(amax == 0.0f ? 0.0f : -8*sum/amax); + si[(iqs % QK8_1)/(QK8_1/2)] = min(tmp, 127); + + if (iqs % QK8_1 != 0) { + return; + } + + reinterpret_cast(y[ib].ds[iqs/QK8_1].x) = d; } } @@ -104,9 +129,14 @@ void quantize_mmq_q8_1_cuda( const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; const dim3 num_blocks(block_num_x, kx1, channels); const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); - if (mmq_need_sum(type_x)) { - quantize_mmq_q8_1<<>>(x, vy, kx0, kx1, kx0_padded); + const int need_sum = mmq_need_sum(type_x); + if (need_sum == 0) { + quantize_mmq_q8_1<0><<>>(x, vy, kx0, kx1, kx0_padded); + } else if (need_sum == 1) { + quantize_mmq_q8_1<1><<>>(x, vy, kx0, kx1, kx0_padded); + } else if (need_sum == 2) { + quantize_mmq_q8_1<2><<>>(x, vy, kx0, kx1, kx0_padded); } else { - quantize_mmq_q8_1<<>>(x, vy, kx0, kx1, kx0_padded); + GGML_ASSERT(false); } } diff --git a/ggml-cuda/softmax.cu b/ggml-cuda/softmax.cu index ce64f2f2ce28b..c24abae1f138c 100644 --- a/ggml-cuda/softmax.cu +++ b/ggml-cuda/softmax.cu @@ -130,6 +130,7 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + // FIXME: this limit could be raised by ~2-4x on Ampere or newer if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) { switch (ncols_x) { case 32: diff --git a/ggml-cuda/vecdotq.cuh b/ggml-cuda/vecdotq.cuh index b9573a7c7d053..6bf4d6b7aa160 100644 --- a/ggml-cuda/vecdotq.cuh +++ b/ggml-cuda/vecdotq.cuh @@ -265,36 +265,32 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq( // contiguous u/y values static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq( - const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ scales, - const half2 & dm2, const float & d8) { + const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const half2 & ds8) { #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - int sumi_d = 0; - int sumi_m = 0; + float sumf_d = 0.0f; + float sumf_m = 0.0f; + + const float d8 = __low2float(ds8); + const int8_t * s8i = (const int8_t *) &ds8.y; #pragma unroll for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) { - int sumi_d_sc = 0; - - const int sc = scales[i0 / (QI8_1/2)]; - - // fill int with 4x m - int m = sc >> 4; - m |= m << 8; - m |= m << 16; + const float2 dm2f = __half22float2(dm2[i0/(QI8_1/2)]); + int sumi_d = 0; + const int vi0 = v[i0/(QI8_1/2)]; #pragma unroll for (int i = i0; i < i0 + QI8_1/2; ++i) { - sumi_d_sc = __dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product - sumi_m = __dp4a(m, u[i], sumi_m); // multiply sum of q8_1 values with m + const int vi = (vi0 >> (2*(i % (QI8_1/2)))) & 0x03030303; + sumi_d = __dp4a(vi, u[i], sumi_d); // SIMD dot product } - sumi_d += sumi_d_sc * (sc & 0xF); + sumf_d += dm2f.x * sumi_d; + sumf_m += dm2f.y * s8i[i0/(QI8_1/2)]; } - const float2 dm2f = __half22float2(dm2); - - return d8 * (dm2f.x*sumi_d - dm2f.y*sumi_m); + return d8*(sumf_d + (127.0f/8.0f)*sumf_m); #else NO_DEVICE_CODE; #endif // __CUDA_ARCH__ >= MIN_CC_DP4A @@ -352,8 +348,10 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq( for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) { int sumi_sc = 0; +#pragma unroll for (int i = i0; i < i0 + QI8_1/2; ++i) { - sumi_sc = __dp4a(v[i], u[i], sumi_sc); // SIMD dot product + const int vi = __vsubss4((v[i/2] >> (4*(i%2))) & 0x0F0F0F0F, 0x04040404); + sumi_sc = __dp4a(vi, u[i], sumi_sc); // SIMD dot product } sumi += sumi_sc * scales[i0 / (QI8_1/2)]; From 87099452ede829d57ee20b517e7ce4747def5e40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 13 Jun 2024 18:06:07 +0200 Subject: [PATCH 2/6] try CI fix --- ggml-cuda/common.cuh | 4 ++++ ggml-cuda/mmq.cuh | 9 ++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 3f51548d01c26..de7c2e4349ede 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -331,6 +331,10 @@ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int #define FP16_AVAILABLE #endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL +#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610 +#define FAST_FP16_AVAILABLE +#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610 + #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA #define FP16_MMA_AVAILABLE #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA diff --git a/ggml-cuda/mmq.cuh b/ggml-cuda/mmq.cuh index 594f0742d144f..c454f3f0ac60d 100644 --- a/ggml-cuda/mmq.cuh +++ b/ggml-cuda/mmq.cuh @@ -839,7 +839,14 @@ template static __device__ __forceinlin } const int sc_m = bxi->scales[kqsx]; - x_dm[i*(WARP_SIZE + 1) + threadIdx.x] = bxi->dm * make_half2(sc_m & 0x0F, sc_m >> 4); +#ifdef FAST_FP16_AVAILABLE + const half2 x_dm_ik = __hmul2(bxi->dm, make_half2(sc_m & 0x0F, sc_m >> 4)); +#else + const float2 bxi_dmf = __half22float2(bxi->dm); + const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4)); +#endif // FAST_FP16_AVAILABLE + + x_dm[i*(WARP_SIZE + 1) + threadIdx.x] = x_dm_ik; } } From 46b4054e6ef4f625e8d28de8f9e4a42a18b3341b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 13 Jun 2024 18:54:14 +0200 Subject: [PATCH 3/6] try CI fix --- ggml-cuda/mmq.cuh | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/ggml-cuda/mmq.cuh b/ggml-cuda/mmq.cuh index c454f3f0ac60d..9233c70aad2f8 100644 --- a/ggml-cuda/mmq.cuh +++ b/ggml-cuda/mmq.cuh @@ -194,7 +194,9 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma( float dA[mma_C::ne/2]; const int i0 = threadIdx.y*mma_A::I; +#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); +#endif // INT8_MMA_AVAILABLE #pragma unroll for (int l = 0; l < mma_A::ne; ++l) { @@ -328,7 +330,9 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma( half2 dmA[mma_C::ne/2]; const int i0 = threadIdx.y*mma_A::I; +#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); +#endif // INT8_MMA_AVAILABLE #pragma unroll for (int l = 0; l < mma_A::ne; ++l) { @@ -485,7 +489,9 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma( float dA[mma_C::ne/2]; const int i0 = threadIdx.y*mma_A::I; +#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); +#endif // INT8_MMA_AVAILABLE #pragma unroll for (int l = 0; l < mma_A::ne; ++l) { @@ -635,7 +641,9 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma( half2 dmA[mma_C::ne/2]; const int i0 = threadIdx.y*mma_A::I; +#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); +#endif // INT8_MMA_AVAILABLE #pragma unroll for (int l = 0; l < mma_A::ne; ++l) { @@ -762,7 +770,9 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( float dA[mma_C::ne/2]; const int i0 = threadIdx.y*mma_A::I; +#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); +#endif // INT8_MMA_AVAILABLE #pragma unroll for (int l = 0; l < mma_A::ne; ++l) { @@ -886,7 +896,9 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( const half2 * y_ds = (const half2 *) y; const int i0 = threadIdx.y*mma_A::I; +#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); +#endif // INT8_MMA_AVAILABLE mma_A A[2]; float dA[mma_C::ne/2][2]; @@ -1071,7 +1083,9 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma( const float * y_df = (const float *) y; const int i0 = threadIdx.y*mma_A::I; +#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); +#endif // INT8_MMA_AVAILABLE mma_A A[2]; int scA[mma_C::ne/2][2]; @@ -1234,7 +1248,9 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma( const half2 * y_ds = (const half2 *) y; const int i0 = threadIdx.y*mma_A::I; +#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); +#endif // INT8_MMA_AVAILABLE mma_A A[2]; int scA[mma_C::ne/2][2]; @@ -1419,7 +1435,9 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma( const half2 * y_ds = (const half2 *) y; const int i0 = threadIdx.y*mma_A::I; +#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); +#endif // INT8_MMA_AVAILABLE mma_A A[2]; int scA[mma_C::ne/2][2]; @@ -1599,7 +1617,9 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( const float * y_df = (const float *) y; const int i0 = threadIdx.y*mma_A::I; +#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); +#endif // INT8_MMA_AVAILABLE mma_A A[4]; int scA[mma_C::ne/2][4]; @@ -1702,7 +1722,9 @@ static __device__ __forceinline__ void mmq_write_back_mma(const float * __restri typedef mma_int_C_I16J8 mma_C; const int i0 = threadIdx.y*mma_C::I; +#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y"); +#endif // INT8_MMA_AVAILABLE #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) { From 80ba2aef4ab4f8b1126e9f79a1167b3a48640034 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 13 Jun 2024 20:22:47 +0200 Subject: [PATCH 4/6] try CI fix --- ggml-cuda/mmq.cuh | 68 ++++++++++++++++++++++++++++++++++------------- 1 file changed, 50 insertions(+), 18 deletions(-) diff --git a/ggml-cuda/mmq.cuh b/ggml-cuda/mmq.cuh index 9233c70aad2f8..c9019242b0f9c 100644 --- a/ggml-cuda/mmq.cuh +++ b/ggml-cuda/mmq.cuh @@ -180,6 +180,7 @@ template static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma( const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { +#ifdef INT8_MMA_AVAILABLE GGML_UNUSED(x_sc); typedef mma_int_A_I16K8 mma_A; @@ -194,9 +195,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma( float dA[mma_C::ne/2]; const int i0 = threadIdx.y*mma_A::I; -#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); -#endif // INT8_MMA_AVAILABLE #pragma unroll for (int l = 0; l < mma_A::ne; ++l) { @@ -239,6 +238,10 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma( sum[(j0/B.J)*C.ne + l] += dA[l/2]*__low2float(dsB[l%2])*C.x[l]; } } +#else + GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q4_1( @@ -317,6 +320,7 @@ template static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma( const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { +#ifdef INT8_MMA_AVAILABLE GGML_UNUSED(x_sc); typedef mma_int_A_I16K8 mma_A; @@ -330,9 +334,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma( half2 dmA[mma_C::ne/2]; const int i0 = threadIdx.y*mma_A::I; -#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); -#endif // INT8_MMA_AVAILABLE #pragma unroll for (int l = 0; l < mma_A::ne; ++l) { @@ -376,6 +378,10 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma( sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB); } } +#else + GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q5_0( @@ -475,6 +481,7 @@ template static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma( const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { +#ifdef INT8_MMA_AVAILABLE GGML_UNUSED(x_sc); typedef mma_int_A_I16K8 mma_A; @@ -489,9 +496,7 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma( float dA[mma_C::ne/2]; const int i0 = threadIdx.y*mma_A::I; -#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); -#endif // INT8_MMA_AVAILABLE #pragma unroll for (int l = 0; l < mma_A::ne; ++l) { @@ -533,6 +538,10 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma( sum[(j0/B.J)*C.ne + l] += dA[l/2]*dB[l%2]*C.x[l]; } } +#else + GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q5_1( @@ -628,6 +637,7 @@ template static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma( const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { +#ifdef INT8_MMA_AVAILABLE GGML_UNUSED(x_sc); typedef mma_int_A_I16K8 mma_A; @@ -641,9 +651,7 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma( half2 dmA[mma_C::ne/2]; const int i0 = threadIdx.y*mma_A::I; -#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); -#endif // INT8_MMA_AVAILABLE #pragma unroll for (int l = 0; l < mma_A::ne; ++l) { @@ -686,6 +694,10 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma( sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB); } } +#else + GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q8_0( @@ -756,6 +768,7 @@ template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { +#ifdef INT8_MMA_AVAILABLE GGML_UNUSED(x_sc); typedef mma_int_A_I16K8 mma_A; @@ -770,9 +783,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( float dA[mma_C::ne/2]; const int i0 = threadIdx.y*mma_A::I; -#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); -#endif // INT8_MMA_AVAILABLE #pragma unroll for (int l = 0; l < mma_A::ne; ++l) { @@ -814,6 +825,10 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( sum[(j0/B.J)*C.ne + l] += C.x[l]*dA[l/2]*dB[l%2]; } } +#else + GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q2_K( @@ -887,6 +902,7 @@ template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { +#ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K4 mma_A; typedef mma_int_B_J8K4 mma_B; @@ -896,9 +912,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( const half2 * y_ds = (const half2 *) y; const int i0 = threadIdx.y*mma_A::I; -#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); -#endif // INT8_MMA_AVAILABLE mma_A A[2]; float dA[mma_C::ne/2][2]; @@ -961,6 +975,10 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( sum[(j0/mma_B::J)*mma_C::ne + l] += (C[0].x[l]*dA[l/2][0] + C[1].x[l]*dA[l/2][1] + mA[l/2][0]*sB[l%2][0] + mA[l/2][1]*sB[l%2][1])*dB[l%2]; } } +#else + GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q3_K( @@ -1073,6 +1091,7 @@ template static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma( const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { +#ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K4 mma_A; typedef mma_int_B_J8K4 mma_B; @@ -1083,9 +1102,7 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma( const float * y_df = (const float *) y; const int i0 = threadIdx.y*mma_A::I; -#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); -#endif // INT8_MMA_AVAILABLE mma_A A[2]; int scA[mma_C::ne/2][2]; @@ -1150,6 +1167,10 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma( sum[(j0/mma_B::J)*mma_C::ne + l] += (C[0].x[l]*scA[l/2][0] + C[1].x[l]*scA[l/2][1])*dA[l/2]*dB[l%2]; } } +#else + GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q4_K( @@ -1239,6 +1260,7 @@ template static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma( const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { +#ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K8 mma_A; typedef mma_int_B_J8K8 mma_B; @@ -1248,9 +1270,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma( const half2 * y_ds = (const half2 *) y; const int i0 = threadIdx.y*mma_A::I; -#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); -#endif // INT8_MMA_AVAILABLE mma_A A[2]; int scA[mma_C::ne/2][2]; @@ -1324,6 +1344,10 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma( sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l]; } } +#else + GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q5_K( @@ -1426,6 +1450,7 @@ template static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma( const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { +#ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K8 mma_A; typedef mma_int_B_J8K8 mma_B; @@ -1435,9 +1460,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma( const half2 * y_ds = (const half2 *) y; const int i0 = threadIdx.y*mma_A::I; -#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); -#endif // INT8_MMA_AVAILABLE mma_A A[2]; int scA[mma_C::ne/2][2]; @@ -1511,6 +1534,10 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma( sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l]; } } +#else + GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q6_K( @@ -1607,6 +1634,7 @@ template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { +#ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K4 mma_A; typedef mma_int_B_J8K4 mma_B; @@ -1692,6 +1720,10 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( sum[(j0/mma_B::J)*mma_C::ne + l] += tmp[l]*dA[l/2]; } } +#else + GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE } template From bff3a20944a851fe443ed8cc326e19861f6d0cf2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Fri, 14 Jun 2024 16:16:44 +0200 Subject: [PATCH 5/6] fix data race --- ggml-cuda/mmq.cuh | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/ggml-cuda/mmq.cuh b/ggml-cuda/mmq.cuh index c9019242b0f9c..774083249abc4 100644 --- a/ggml-cuda/mmq.cuh +++ b/ggml-cuda/mmq.cuh @@ -850,16 +850,18 @@ template static __device__ __forceinlin const int x_ql_0 = get_int_from_uint8(bxi->qs, kqsx); - x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = 0; - #pragma unroll - for (int l = 0; l < QR3_K; ++l) { + for (int l = 0; l < QR2_K; ++l) { const int k = kbx*QI2_K + (kqsx/8)*8 + l*2 + (kqsx % 8)/4; int x_qs_k = ((x_ql_0 >> (2*l)) & 0x03030303) << (2*(kqsx % 4)); x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE); x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 2, WARP_SIZE); + if (kqsx % QR2_K != 0) { + continue; + } + x_qs[i*(WARP_SIZE + 1) + k] = x_qs_k; } @@ -1011,6 +1013,10 @@ template static __device__ __forceinlin int x_qs_k = (x_ql_k | x_qh_k) << (4*(k%2)); x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE); + if (kqsx % 2 != 0) { + continue; + } + x_qs[i*(2*WARP_SIZE + 1) + k/2] = x_qs_k; } } From 1d9dd480ff6c2d2c59e9f33f87fc01325d3a4ff7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Fri, 14 Jun 2024 17:43:33 +0200 Subject: [PATCH 6/6] rever q2_K precision related changes --- ggml-cuda/mmq.cuh | 47 +++++++++++++++++------------------ ggml-cuda/quantize.cu | 58 +++++++++++-------------------------------- ggml-cuda/vecdotq.cuh | 13 +++++----- 3 files changed, 43 insertions(+), 75 deletions(-) diff --git a/ggml-cuda/mmq.cuh b/ggml-cuda/mmq.cuh index 774083249abc4..6d57974fb4e7c 100644 --- a/ggml-cuda/mmq.cuh +++ b/ggml-cuda/mmq.cuh @@ -882,8 +882,8 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { @@ -895,7 +895,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq( &x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR2_K*k0) % WARP_SIZE], - &x_dm[i*(WARP_SIZE + 1) + k0], y_ds[j*MMQ_TILE_Y_K + ((QR2_K*k0) % WARP_SIZE)/QI8_1]); + &x_dm[i*(WARP_SIZE + 1) + k0], y_df[j*MMQ_TILE_Y_K + ((QR2_K*k0) % WARP_SIZE)/QI8_1]); } } } @@ -911,7 +911,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( typedef mma_int_C_I16J8 mma_C; const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; + const float * y_df = (const float *) y; const int i0 = threadIdx.y*mma_A::I; static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); @@ -944,10 +944,10 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) { - mma_C C[2]; + mma_C Cd[2]; + mma_C Cm[2]; mma_B B[2]; float dB[mma_C::ne/2]; - float sB[mma_C::ne/2][2]; #pragma unroll for (int l = 0; l < mma_B::ne; ++l) { @@ -961,20 +961,21 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); - const half2 tmp = y_ds[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)]; - dB[l] = __low2float(tmp); - - const int8_t * sBi = (const int8_t *) &tmp.y; - sB[l][0] = (127.0f/8.0f)*sBi[0]; - sB[l][1] = (127.0f/8.0f)*sBi[1]; + dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)]; } - C[0].mma_K4(A[0], B[0]); - C[1].mma_K4(A[1], B[1]); + Cd[0].mma_K4(A[0], B[0]); + Cd[1].mma_K4(A[1], B[1]); + + mma_A A1; + A1.x[0] = 0x01010101; + A1.x[1] = 0x01010101; + Cm[0].mma_K4(A1, B[0]); + Cm[1].mma_K4(A1, B[1]); #pragma unroll for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_B::J)*mma_C::ne + l] += (C[0].x[l]*dA[l/2][0] + C[1].x[l]*dA[l/2][1] + mA[l/2][0]*sB[l%2][0] + mA[l/2][1]*sB[l%2][1])*dB[l%2]; + sum[(j0/mma_B::J)*mma_C::ne + l] += (Cd[0].x[l]*dA[l/2][0] + Cd[1].x[l]*dA[l/2][1] - Cm[0].x[l]*mA[l/2][0] - Cm[1].x[l]*mA[l/2][1])*dB[l%2]; } } #else @@ -1870,26 +1871,24 @@ struct mmq_type_traits { static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a; }; -static int mmq_need_sum(const ggml_type type_x) { +static bool mmq_need_sum(const ggml_type type_x) { switch (type_x) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: - return 1; + return true; case GGML_TYPE_Q5_0: - return 0; + return false; case GGML_TYPE_Q5_1: - return 1; + return true; case GGML_TYPE_Q8_0: - return 0; case GGML_TYPE_Q2_K: - return 2; case GGML_TYPE_Q3_K: - return 0; + return false; case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: - return 1; + return true; case GGML_TYPE_Q6_K: - return 0; + return false; default: GGML_ASSERT(false); break; diff --git a/ggml-cuda/quantize.cu b/ggml-cuda/quantize.cu index 8d61d8bd625f5..b4678682238d3 100644 --- a/ggml-cuda/quantize.cu +++ b/ggml-cuda/quantize.cu @@ -1,5 +1,4 @@ #include "quantize.cuh" -#include #include static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded) { @@ -38,7 +37,7 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest reinterpret_cast(y[ib].ds.y) = sum; } -template +template static __global__ void quantize_mmq_q8_1( const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) { @@ -61,48 +60,24 @@ static __global__ void quantize_mmq_q8_1( amax = warp_reduce_max(amax); + float sum; + if (need_sum) { + sum = warp_reduce_sum(xi); + } + const float d = amax / 127; const int8_t q = amax == 0.0f ? 0 : roundf(xi / d); y[ib].qs[iqs] = q; - static_assert(need_sum >= 0 && need_sum <= 2, "Invalid need_sum value."); - if (need_sum == 0) { - if (iqs % QK8_1 != 0) { - return; - } - - ((float *) y[ib].ds)[iqs/QK8_1] = d; - } else if (need_sum == 1) { - const float sum = warp_reduce_sum(xi); - - if (iqs % QK8_1 != 0) { - return; - } + if (iqs % QK8_1 != 0) { + return; + } + if (need_sum) { y[ib].ds[iqs/QK8_1] = make_half2(d, sum); } else { - float sum = xi; - - // Calculate sum per 16 values: -#pragma unroll - for (int mask = 8; mask > 0; mask >>= 1) { - sum += __shfl_xor_sync(0xffffffff, sum, mask, 32); - } - - if (iqs % (QK8_1/2) != 0) { - return; - } - - int8_t * si = (int8_t *) &y[ib].ds[iqs/QK8_1].y; - const int tmp = roundf(amax == 0.0f ? 0.0f : -8*sum/amax); - si[(iqs % QK8_1)/(QK8_1/2)] = min(tmp, 127); - - if (iqs % QK8_1 != 0) { - return; - } - - reinterpret_cast(y[ib].ds[iqs/QK8_1].x) = d; + ((float *) y[ib].ds)[iqs/QK8_1] = d; } } @@ -129,14 +104,9 @@ void quantize_mmq_q8_1_cuda( const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; const dim3 num_blocks(block_num_x, kx1, channels); const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); - const int need_sum = mmq_need_sum(type_x); - if (need_sum == 0) { - quantize_mmq_q8_1<0><<>>(x, vy, kx0, kx1, kx0_padded); - } else if (need_sum == 1) { - quantize_mmq_q8_1<1><<>>(x, vy, kx0, kx1, kx0_padded); - } else if (need_sum == 2) { - quantize_mmq_q8_1<2><<>>(x, vy, kx0, kx1, kx0_padded); + if (mmq_need_sum(type_x)) { + quantize_mmq_q8_1<<>>(x, vy, kx0, kx1, kx0_padded); } else { - GGML_ASSERT(false); + quantize_mmq_q8_1<<>>(x, vy, kx0, kx1, kx0_padded); } } diff --git a/ggml-cuda/vecdotq.cuh b/ggml-cuda/vecdotq.cuh index 6bf4d6b7aa160..3b12d656616be 100644 --- a/ggml-cuda/vecdotq.cuh +++ b/ggml-cuda/vecdotq.cuh @@ -265,32 +265,31 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq( // contiguous u/y values static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq( - const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const half2 & ds8) { + const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8) { #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics float sumf_d = 0.0f; float sumf_m = 0.0f; - const float d8 = __low2float(ds8); - const int8_t * s8i = (const int8_t *) &ds8.y; - #pragma unroll for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) { const float2 dm2f = __half22float2(dm2[i0/(QI8_1/2)]); int sumi_d = 0; + int sumi_m = 0; const int vi0 = v[i0/(QI8_1/2)]; #pragma unroll for (int i = i0; i < i0 + QI8_1/2; ++i) { const int vi = (vi0 >> (2*(i % (QI8_1/2)))) & 0x03030303; - sumi_d = __dp4a(vi, u[i], sumi_d); // SIMD dot product + sumi_d = __dp4a(vi, u[i], sumi_d); // SIMD dot product + sumi_m = __dp4a(0x01010101, u[i], sumi_m); } sumf_d += dm2f.x * sumi_d; - sumf_m += dm2f.y * s8i[i0/(QI8_1/2)]; + sumf_m += dm2f.y * sumi_m; } - return d8*(sumf_d + (127.0f/8.0f)*sumf_m); + return d8*(sumf_d - sumf_m); #else NO_DEVICE_CODE; #endif // __CUDA_ARCH__ >= MIN_CC_DP4A