From ac439e9c73577801e0b1282d6843121ba3dd7f00 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Thu, 19 Dec 2024 22:21:10 +0000 Subject: [PATCH 1/4] add exllamav2 kernels --- CMakeLists.txt | 4 +- csrc/ops.h | 7 + csrc/quantization/exllamav2/compat.cuh | 84 +++ csrc/quantization/exllamav2/matrix_view.cuh | 164 +++++ csrc/quantization/exllamav2/q_gemm.cu | 159 +++++ csrc/quantization/exllamav2/q_gemm.cuh | 583 ++++++++++++++++++ csrc/quantization/exllamav2/q_matrix.cu | 383 ++++++++++++ csrc/quantization/exllamav2/q_matrix.cuh | 85 +++ csrc/quantization/exllamav2/quant/qdq_2.cuh | 91 +++ csrc/quantization/exllamav2/quant/qdq_3.cuh | 171 +++++ csrc/quantization/exllamav2/quant/qdq_4.cuh | 145 +++++ csrc/quantization/exllamav2/quant/qdq_5.cuh | 212 +++++++ csrc/quantization/exllamav2/quant/qdq_6.cuh | 52 ++ csrc/quantization/exllamav2/quant/qdq_8.cuh | 47 ++ .../quantization/exllamav2/quant/qdq_util.cuh | 76 +++ csrc/torch_bindings.cpp | 11 + 16 files changed, 2273 insertions(+), 1 deletion(-) create mode 100644 csrc/quantization/exllamav2/compat.cuh create mode 100644 csrc/quantization/exllamav2/matrix_view.cuh create mode 100644 csrc/quantization/exllamav2/q_gemm.cu create mode 100644 csrc/quantization/exllamav2/q_gemm.cuh create mode 100644 csrc/quantization/exllamav2/q_matrix.cu create mode 100644 csrc/quantization/exllamav2/q_matrix.cuh create mode 100644 csrc/quantization/exllamav2/quant/qdq_2.cuh create mode 100644 csrc/quantization/exllamav2/quant/qdq_3.cuh create mode 100644 csrc/quantization/exllamav2/quant/qdq_4.cuh create mode 100644 csrc/quantization/exllamav2/quant/qdq_5.cuh create mode 100644 csrc/quantization/exllamav2/quant/qdq_6.cuh create mode 100644 csrc/quantization/exllamav2/quant/qdq_8.cuh create mode 100644 csrc/quantization/exllamav2/quant/qdq_util.cuh diff --git a/CMakeLists.txt b/CMakeLists.txt index 83c8033434f3b..7394708725846 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -244,7 +244,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" "csrc/sparse/cutlass/sparse_compressor_entry.cu" - "csrc/cutlass_extensions/common.cpp") + "csrc/cutlass_extensions/common.cpp" + "csrc/quantization/exllamav2/q_matrix.cu" + "csrc/quantization/exllamav2/q_gemm.cu") set_gencode_flags_for_srcs( SRCS "${VLLM_EXT_SRC}" diff --git a/csrc/ops.h b/csrc/ops.h index 347c502845d8f..f1f9914c5686e 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -200,6 +200,13 @@ void dynamic_per_token_scaled_fp8_quant( torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale, c10::optional const& scale_ub); +uintptr_t make_q_matrix(torch::Tensor q_weight, torch::Tensor q_perm, + torch::Tensor q_invperm, torch::Tensor q_scale, + torch::Tensor q_scale_max, torch::Tensor q_groups, + torch::Tensor q_group_map); + +torch::Tensor exl2_gemm(torch::Tensor a, uintptr_t b); + void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& C, diff --git a/csrc/quantization/exllamav2/compat.cuh b/csrc/quantization/exllamav2/compat.cuh new file mode 100644 index 0000000000000..17df06478774f --- /dev/null +++ b/csrc/quantization/exllamav2/compat.cuh @@ -0,0 +1,84 @@ +/* + * Adapted from https://github.com/turboderp/exllamav2 + * Copyright (c) 2024 turboderp + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without 
restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef _compat_cuh + #define _compat_cuh + +namespace vllm { +namespace exl2 { + +// atomicAdd for half types, to support CC < 7.x + +__device__ __forceinline__ void atomicAdd_half(half* address, half val) { + unsigned int* address_as_ui = + (unsigned int*)((char*)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + do { + assumed = old; + __half_raw hsum; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + half tmpres = __hadd(hsum, val); + hsum = __half_raw(tmpres); + old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) + : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } while (assumed != old); +} + +// atomicAdd for half2 types + +__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) { + unsigned int* address_as_ui = (unsigned int*)address; + unsigned int old = *address_as_ui; + unsigned int assumed; + do { + assumed = old; + half2 old_val = *((half2*)&old); + half2 new_val = __hadd2(old_val, val); + old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); + } while (assumed != old); +} + +// + + #if defined(__CUDA_ARCH__) || defined(USE_ROCM) + #if __CUDA_ARCH__ < 700 || defined(USE_ROCM) + +__device__ __forceinline__ void atomicAdd(half* address, half val) { + atomicAdd_half(address, val); +} + + #if __CUDA_ARCH__ < 600 || defined(USE_ROCM) +__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { + atomicAdd_half2(address, val); +} + #endif + + #endif + #endif + +#endif + +} // namespace exl2 +} // namespace vllm \ No newline at end of file diff --git a/csrc/quantization/exllamav2/matrix_view.cuh b/csrc/quantization/exllamav2/matrix_view.cuh new file mode 100644 index 0000000000000..49f1aee0d96d5 --- /dev/null +++ b/csrc/quantization/exllamav2/matrix_view.cuh @@ -0,0 +1,164 @@ +/* + * Adapted from https://github.com/turboderp/exllamav2 + * Copyright (c) 2024 turboderp + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef _matrix_view_cuh +#define _matrix_view_cuh + +#include +#include + +#include "quant/qdq_util.cuh" + +namespace vllm { +namespace exl2 { + +class MatrixView_half { + public: + const half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half(const half* data, const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ half item(int row, int column) const { + return data[row * width + column]; + } + __device__ __forceinline__ half2 item_half2(int row, int column) const { + return ((half2*)data)[(row * width + column) / 2]; + } + __device__ __forceinline__ half2 item_half2half2(int row, int column) const { + return __half2half2(data[row * width + column]); + } + __device__ __forceinline__ const half* item_ptr(int row, int column) const { + return &data[row * width + column]; + } + + __device__ __forceinline__ void item4(half (&items)[4], int row, + int column) const { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __low2half(i01); + items[1] = __high2half(i01); + items[2] = __low2half(i23); + items[3] = __high2half(i23); + } + __device__ __forceinline__ void item4_f(float (&items)[4], int row, + int column) const { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2float(__low2half(i01)); + items[1] = __half2float(__high2half(i01)); + items[2] = __half2float(__low2half(i23)); + items[3] = __half2float(__high2half(i23)); + } + + __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, + int column) const { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2half2(__low2half(i01)); + items[1] = __half2half2(__high2half(i01)); + items[2] = __half2half2(__low2half(i23)); + items[3] = __half2half2(__high2half(i23)); + } +}; + +class MatrixView_half_rw { + public: + half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ half item(int row, int column) const { + return data[row * width + column]; + } + __device__ __forceinline__ half2 item_half2(int row, int column) const { + return ((half2*)data)[(row * width + column) / 2]; + } + __device__ __forceinline__ half* item_ptr(int row, int column) { + return &data[row * width + column]; + } + __device__ __forceinline__ void set(int row, int column, half value) { + data[row * width + column] = value; + } + __device__ __forceinline__ void set_half2(int row, int column, half2 value) { + ((half2*)data)[(row * width + column) / 2] = value; + } + + __device__ __forceinline__ void set4(int row, int column, half v0, half v1, + half v2, half v3) { + half2 v01 = __halves2half2(v0, v1); + half2 v23 = __halves2half2(v2, v3); + half2* ptr = (half2*)item_ptr(row, column); + ptr[0] = v01; + ptr[1] = v23; + } +}; + +class 
MatrixView_q4_row { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (column & 0x07) * 4; + return (data[row * width / 8 + column / 8] >> shift) & 0x0f; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, + int column) const { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + items[2] = (d >> 8) & 0x0f; + items[3] = (d >> 12) & 0x0f; + } +}; + +} // namespace exl2 +} // namespace vllm + +#endif \ No newline at end of file diff --git a/csrc/quantization/exllamav2/q_gemm.cu b/csrc/quantization/exllamav2/q_gemm.cu new file mode 100644 index 0000000000000..b93e9ed9e9adc --- /dev/null +++ b/csrc/quantization/exllamav2/q_gemm.cu @@ -0,0 +1,159 @@ +/* + * Adapted from https://github.com/turboderp/exllamav2 + * Copyright (c) 2024 turboderp + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#include +#include +#include +#include + +#include "q_matrix.cuh" +#include "matrix_view.cuh" +#include "quant/qdq_2.cuh" +#include "quant/qdq_3.cuh" +#include "quant/qdq_4.cuh" +#include "quant/qdq_5.cuh" +#include "quant/qdq_6.cuh" +#include "quant/qdq_8.cuh" +#include "q_gemm.cuh" + +namespace vllm { +namespace exl2 { + +#define MAX_Q_GEMM_ROWS 32 +#define EXL2_BLOCK_KN_SIZE 64 +#define EXL2_BLOCK_M_SIZE_MAX 8 +#define EXL2_MAX_GROUPS_IN_BLOCK (EXL2_BLOCK_KN_SIZE / 32) +#if defined(USE_ROCM) +__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm( + hipblasHandle_t handle, hipblasOperation_t transA, + hipblasOperation_t transB, int m, int n, int k, const half* alpha, + const half* AP, int lda, const half* BP, int ldb, const half* beta, + half* CP, int ldc) { + return hipblasHgemm(handle, transA, transB, m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(AP), lda, + reinterpret_cast(BP), ldb, + reinterpret_cast(beta), + reinterpret_cast(CP), ldc); +} + #define hipblasHgemm __compat_hipblasHgemm +#endif +#define DIVIDE(x, size) (((x) + (size) - 1) / (size)) + +void gemm_half_q_half_cuda_part(const half* a, QMatrix* b, half* c, int size_m, + int size_n, int size_k, int m_count, + bool clear) { + { + dim3 blockDim, gridDim; + blockDim.x = EXL2_BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, EXL2_BLOCK_KN_SIZE * 4); + gridDim.y = DIVIDE(size_m, m_count); + gridDim.z = DIVIDE(b->height, EXL2_BLOCK_KN_SIZE); + + fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + kernel<<>>( + a, b->cuda_q_weight, b->cuda_q_scale, b->cuda_q_scale_max, c, size_m, + size_n, size_k, b->height, b->groups, b->cuda_q_group_map, + b->cuda_q_perm, b->rows_8, b->rows_6, b->rows_5, b->rows_4, b->rows_3, + b->rows_2, clear); + } +} + +void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a, + QMatrix* b, half* c, int size_m, int size_n, + int size_k, bool clear, half* temp_dq) { + if (size_m > MAX_Q_GEMM_ROWS) { + // Reconstruct FP16 matrix, then cuBLAS + b->reconstruct(temp_dq); + + // cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH); + + const half alpha = __float2half(1.0f); + const half beta = clear ? 
__float2half(0.0f) : __float2half(1.0f); + cublasHgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, size_n, size_m, size_k, + &alpha, temp_dq, size_n, a, size_k, &beta, c, size_n); + } else { + // Quantized matmul + + int block_m_size_max = EXL2_BLOCK_M_SIZE_MAX; + int max_chunks = size_m / block_m_size_max; + int last_chunk = max_chunks * block_m_size_max; + int last_chunk_size = size_m - last_chunk; + + if (max_chunks) { + gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, + block_m_size_max, clear); + } + + if (last_chunk_size) { + gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, + c + last_chunk * size_n, last_chunk_size, + size_n, size_k, last_chunk_size, clear); + } + } +} + +} // namespace exl2 +} // namespace vllm + +torch::Tensor exl2_gemm(torch::Tensor a, uintptr_t b) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + vllm::exl2::QMatrix* qm = reinterpret_cast(b); + + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + at::Tensor c = torch::empty({a.size(0), qm->width}, options); + at::Tensor temp_dq; + if (c.size(0) > MAX_Q_GEMM_ROWS) { + temp_dq = torch::zeros({a.size(1), qm->width}, options); + } + + vllm::exl2::gemm_half_q_half_cuda( + at::cuda::getCurrentCUDABlasHandle(), (const half*)a.data_ptr(), qm, + (half*)c.data_ptr(), + c.size(0), // m + c.size(1), // n + a.size(1), // k + true, c.size(0) > MAX_Q_GEMM_ROWS ? (half*)temp_dq.data_ptr() : NULL); + return c; +} + +uintptr_t make_q_matrix(torch::Tensor q_weight, torch::Tensor q_perm, + torch::Tensor q_invperm, torch::Tensor q_scale, + torch::Tensor q_scale_max, torch::Tensor q_groups, + torch::Tensor q_group_map) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight)); + int device = q_weight.device().index(); + int width = q_weight.size(1); + int groups = q_scale.size(0); + int height = q_perm.size(0); + + vllm::exl2::QMatrix* m = new vllm::exl2::QMatrix( + device, height, width, groups, (uint32_t*)q_weight.data_ptr(), + (uint16_t*)q_perm.data_ptr(), (uint16_t*)q_invperm.data_ptr(), + (uint32_t*)q_scale.data_ptr(), (half*)q_scale_max.data_ptr(), + (uint16_t*)q_groups.data_ptr(), (uint16_t*)q_group_map.data_ptr()); + return reinterpret_cast(m); +} \ No newline at end of file diff --git a/csrc/quantization/exllamav2/q_gemm.cuh b/csrc/quantization/exllamav2/q_gemm.cuh new file mode 100644 index 0000000000000..b93e4aac59ca6 --- /dev/null +++ b/csrc/quantization/exllamav2/q_gemm.cuh @@ -0,0 +1,583 @@ +/* + * Adapted from https://github.com/turboderp/exllamav2 + * Copyright (c) 2024 turboderp + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "compat.cuh" + +namespace vllm { +namespace exl2 { + +#define MAX_Q_GEMM_WEIGHTS 4 +#define EXL2_BLOCK_KN_SIZE 64 +#define EXL2_BLOCK_M_SIZE_MAX 8 +#define EXL2_MAX_GROUPS_IN_BLOCK (EXL2_BLOCK_KN_SIZE / 32) + +__forceinline__ __device__ half2 dot22_8(half2 (&dq)[4], const half* a_ptr, + const half2 g_result, + const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +} + +__forceinline__ __device__ half2 dot22_16(half2 (&dq)[8], const half* a_ptr, + const half2 g_result, + const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +} + +__forceinline__ __device__ half2 dot22_32(half2 (&dq)[16], const half* a_ptr, + const half2 g_result, + const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +} + +__forceinline__ __device__ float dot22_8_f(half2 (&dq)[4], const half* a_ptr, + const float g_result, + const float qs_f) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = + __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ float dot22_16_f(half2 (&dq)[8], const half* a_ptr, + const float g_result, + const float qs_f) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = + __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ float dot22_32_f(half2 (&dq)[16], const half* a_ptr, + const float g_result, + const float qs_f) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = + __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ half dot22_8_h(half2 (&dq)[4], const half* a_ptr, + const half g_result, + const half qs_h) { + // Use FP32 accumulator to avoid potential overflow since unscaled weights are + // in the range -128..127 + + float result = {}; +#pragma unroll + for (int i = 0; i < 4; i++) { + half2 w01 = dq[i]; + float w0 = __low2float(w01); + float w1 = __high2float(w01); + float x0 = __half2float(*a_ptr++); + float x1 = __half2float(*a_ptr++); + result = fma(w0, x0, result); + result = fma(w1, x1, result); + } + float qs = __half2float(qs_h); + result *= qs; + half result_h = __float2half_rn(result); + return __hadd(result_h, g_result); +} + +__forceinline__ __device__ half dot22_16_h(half2 (&dq)[8], const half* a_ptr, + const 
half g_result, + const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); +} + +__forceinline__ __device__ half dot22_32_h(half2 (&dq)[16], const half* a_ptr, + const half g_result, + const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); +} + +typedef void (*fp_gemm_half_q_half_kernel)( + const half*, const uint32_t*, const uint32_t*, const half*, half*, + const int, const int, const int, const int, const int, const uint16_t*, + const uint16_t*, const int, const int, const int, const int, const int, + const int, const bool); + +template +__global__ void gemm_half_q_half_kernel( + const half* __restrict__ a, const uint32_t* __restrict__ b_q_weight, + const uint32_t* __restrict__ b_q_scale, + const half* __restrict__ b_q_scale_max, half* __restrict__ c, + const int size_m, const int size_n, const int size_k, const int height, + const int groups, const uint16_t* __restrict__ b_q_group_map, + const uint16_t* __restrict__ b_q_perm, const int rows_8, const int rows_6, + const int rows_5, const int rows_4, const int rows_3, const int rows_2, + const bool clear) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n); + + int t = threadIdx.x; + + // Block + + int offset_n = blockIdx.x * EXL2_BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * EXL2_BLOCK_KN_SIZE; + + int end_n = min(offset_n + EXL2_BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + EXL2_BLOCK_KN_SIZE, height); + int n = offset_n + t * 4; + + // Read weights + + half_uint16 weights[MAX_Q_GEMM_WEIGHTS]; + + // Preload block_a + + __shared__ half block_a[m_count][EXL2_BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) { + for (int m = 0; m < m_count; ++m) { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + half a0 = a_ptr[b_q_perm[offset_k + t]]; + // half a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; + } + } + + // Clear + + if (n >= size_n) return; + + if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0) + { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + + // int group = offset_k / groupsize; + int group = b_q_group_map[offset_k * 2]; + + // if (offset_m == 0 && t == 0) + // DBGI2(offset_k, group); + + // Preload scales + + half scales[EXL2_MAX_GROUPS_IN_BLOCK][4]; + + // int groups_in_block = DIVIDE((end_k - offset_k), groupsize); + int temp_k = offset_k; + for (int g = 0; temp_k < end_k; g++) { + int qscales[4]; + b_q_scale_.item4(qscales, group + g, n); + qscales[0]++; + qscales[1]++; + qscales[2]++; + qscales[3]++; + half maxscale = b_q_scale_max[group + g]; + scales[g][0] = __hmul(__int2half_rn(qscales[0] * qscales[0]), maxscale); + scales[g][1] = __hmul(__int2half_rn(qscales[1] * qscales[1]), maxscale); + scales[g][2] = __hmul(__int2half_rn(qscales[2] * qscales[2]), maxscale); + scales[g][3] = __hmul(__int2half_rn(qscales[3] * qscales[3]), maxscale); + temp_k += 
b_q_group_map[temp_k * 2 + 1]; + } + + // a, b offset + + int pre_rows_8 = min(rows_8, offset_k); + int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0; + int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0; + int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0; + int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0; + int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0; + int qk = 0; + qk += pre_rows_8 / 32 * 8; + qk += pre_rows_6 / 32 * 6; + qk += pre_rows_5 / 32 * 5; + qk += pre_rows_4 / 32 * 4; + qk += pre_rows_3 / 32 * 3; + qk += pre_rows_2 / 32 * 2; + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = EXL2_BLOCK_KN_SIZE; + + // Initial group + + int scales_idx = 0; + half qs_h0 = scales[scales_idx][0]; + half qs_h1 = scales[scales_idx][1]; + half qs_h2 = scales[scales_idx][2]; + half qs_h3 = scales[scales_idx][3]; + int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1]; + + // Column result + + half block_c[m_count][4] = {}; + + // Dequantize groups + + int k = offset_k; + + while (k < rows_8 && k < end_k) { + if (k == nextgroup) { + group++; + scales_idx++; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; + } + +#pragma unroll + for (int j = 0; j < 4; j++) { + int4 load_int4[2]; + load_int4[0] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4*)b_ptr); + b_ptr += size_n; + + half2 dq[4][4]; + dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n); + dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n); + dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n); + dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) { + block_c[m][0] = + dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = + dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = + dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = + dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); + } + a_ptr += 8; + } + k += 32; + } + + while (k < rows_6 && k < end_k) { + if (k == nextgroup) { + group++; + scales_idx++; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; + } + +#pragma unroll + for (int j = 0; j < 2; j++) { + int4 load_int4[3]; + load_int4[0] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[2] = *((int4*)b_ptr); + b_ptr += size_n; + + half2 dq[4][8]; + dequant_6bit_16(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], + size_n); + dequant_6bit_16(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], + size_n); + dequant_6bit_16(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], + size_n); + dequant_6bit_16(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], + size_n); + + for (int m = 0; m < m_count; m++) { + block_c[m][0] = + dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = + dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = + dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = + dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); + } + a_ptr += 16; + } + k += 32; + } + + 
while (k < rows_5 && k < end_k) { + if (k == nextgroup) { + group++; + scales_idx++; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; + } + +#pragma unroll + for (int j = 0; j < 1; j++) { + int4 load_int4[5]; + load_int4[0] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[2] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[3] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[4] = *((int4*)b_ptr); + b_ptr += size_n; + + half2 dq[4][16]; + dequant_5bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, + load_int4[3].x, load_int4[4].x, dq[0], size_n); + dequant_5bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, + load_int4[3].y, load_int4[4].y, dq[1], size_n); + dequant_5bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, + load_int4[3].z, load_int4[4].z, dq[2], size_n); + dequant_5bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, + load_int4[3].w, load_int4[4].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) { + block_c[m][0] = + dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = + dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = + dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = + dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); + } + a_ptr += 32; + } + + k += 32; + } + + while (k < rows_4 && k < end_k) { + if (k == nextgroup) { + group++; + scales_idx++; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; + } + +#pragma unroll + for (int j = 0; j < 4; j++) { + int4 load_int4[1]; + load_int4[0] = *((int4*)b_ptr); + b_ptr += size_n; + + half2 dq[4][4]; + dequant_4bit_8(load_int4[0].x, dq[0], size_n); + dequant_4bit_8(load_int4[0].y, dq[1], size_n); + dequant_4bit_8(load_int4[0].z, dq[2], size_n); + dequant_4bit_8(load_int4[0].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) { + block_c[m][0] = + dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = + dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = + dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = + dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); + } + a_ptr += 8; + } + k += 32; + } + + while (k < rows_3 && k < end_k) { + if (k == nextgroup) { + group++; + scales_idx++; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; + } + +#pragma unroll + for (int j = 0; j < 1; j++) { + int4 load_int4[3]; + load_int4[0] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[2] = *((int4*)b_ptr); + b_ptr += size_n; + + half2 dq[4][16]; + dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], + size_n); + dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], + size_n); + dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], + size_n); + dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], + size_n); + + for (int m = 0; m < m_count; m++) { + block_c[m][0] = + dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = + dot22_32_h(dq[1], a_ptr + m * a_stride, 
block_c[m][1], qs_h1); + block_c[m][2] = + dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = + dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); + } + a_ptr += 32; + } + k += 32; + } + + while (k < rows_2 && k < end_k) { + if (k == nextgroup) { + group++; + scales_idx++; + qs_h0 = scales[scales_idx][0]; + qs_h1 = scales[scales_idx][1]; + qs_h2 = scales[scales_idx][2]; + qs_h3 = scales[scales_idx][3]; + nextgroup += b_q_group_map[k * 2 + 1]; + } + +#pragma unroll + for (int j = 0; j < 1; j++) { + int4 load_int4[1]; + load_int4[0] = *((int4*)b_ptr); + b_ptr += size_n; + + half2 dq[4][8]; + dequant_2bit_16(load_int4[0].x, dq[0], size_n); + dequant_2bit_16(load_int4[0].y, dq[1], size_n); + dequant_2bit_16(load_int4[0].z, dq[2], size_n); + dequant_2bit_16(load_int4[0].w, dq[3], size_n); + + for (int m = 0; m < m_count; m++) { + block_c[m][0] = + dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0); + block_c[m][1] = + dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1); + block_c[m][2] = + dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2); + block_c[m][3] = + dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3); + } + + a_ptr += 16; + } + k += 16; + } + + // Accumulate column sums in c + + for (int m = 0; m < m_count; m++) { + half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + + atomicAdd(out, result01); + atomicAdd(out + 1, result23); + // *out = result01; + // *(out + 1) = result23; + } +} + +struct map_m_count_exl2 { + static constexpr fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel( + const int m_count) { +#if EXL2_BLOCK_M_SIZE_MAX >= 1 + if (m_count == 1) return gemm_half_q_half_kernel<1>; +#endif +#if EXL2_BLOCK_M_SIZE_MAX >= 2 + if (m_count == 2) return gemm_half_q_half_kernel<2>; +#endif +#if EXL2_BLOCK_M_SIZE_MAX >= 3 + if (m_count == 3) return gemm_half_q_half_kernel<3>; +#endif +#if EXL2_BLOCK_M_SIZE_MAX >= 4 + if (m_count == 4) return gemm_half_q_half_kernel<4>; +#endif +#if EXL2_BLOCK_M_SIZE_MAX >= 5 + if (m_count == 5) return gemm_half_q_half_kernel<5>; +#endif +#if EXL2_BLOCK_M_SIZE_MAX >= 6 + if (m_count == 6) return gemm_half_q_half_kernel<6>; +#endif +#if EXL2_BLOCK_M_SIZE_MAX >= 7 + if (m_count == 7) return gemm_half_q_half_kernel<7>; +#endif +#if EXL2_BLOCK_M_SIZE_MAX >= 8 + if (m_count == 8) return gemm_half_q_half_kernel<8>; +#endif + return NULL; + } +}; + +fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count) { + return map_m_count_exl2::pick_gemm_half_q_half_kernel(m_count); +} + +} // namespace exl2 +} // namespace vllm \ No newline at end of file diff --git a/csrc/quantization/exllamav2/q_matrix.cu b/csrc/quantization/exllamav2/q_matrix.cu new file mode 100644 index 0000000000000..198786b2b0b0a --- /dev/null +++ b/csrc/quantization/exllamav2/q_matrix.cu @@ -0,0 +1,383 @@ +/* + * Adapted from https://github.com/turboderp/exllamav2 + * Copyright (c) 2024 turboderp + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The 
above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include +#include +#include +#include + +#include "q_matrix.cuh" +#include "matrix_view.cuh" + +#include "quant/qdq_2.cuh" +#include "quant/qdq_3.cuh" +#include "quant/qdq_4.cuh" +#include "quant/qdq_5.cuh" +#include "quant/qdq_6.cuh" +#include "quant/qdq_8.cuh" + +namespace vllm { +namespace exl2 { + +#define BLOCK_KN_SIZE 128 + +#define THREADS_X 32 +#define THREADS_Y 32 + +#define DIVIDE(x, size) (((x) + (size) - 1) / (size)) + +// Shuffle quantized data on load + +__global__ void shuffle_kernel(uint32_t* __restrict__ b_q_weight, + const int size_k, const int size_n, + const int rows_8, const int rows_6, + const int rows_5, const int rows_4, + const int rows_3, const int rows_2) { + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < rows_8) { + shuffle_8bit_4(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 4; + } + while (k < rows_6) { + shuffle_6bit_16(b_ptr, size_n); + b_ptr += 3 * size_n; + k += 16; + } + while (k < rows_5) { + shuffle_5bit_32(b_ptr, size_n); + b_ptr += 5 * size_n; + k += 32; + } + while (k < rows_4) { + shuffle_4bit_8(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 8; + } + while (k < rows_3) { + shuffle_3bit_32(b_ptr, size_n); + b_ptr += 3 * size_n; + k += 32; + } + while (k < rows_2) { + shuffle_2bit_16(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 16; + } +} + +// QMatrix constructor + +QMatrix::QMatrix(const int _device, const int _height, const int _width, + const int _groups, + + uint32_t* _q_weight, uint16_t* _q_perm, uint16_t* _q_invperm, + uint32_t* _q_scale, half* _q_scale_max, uint16_t* _q_groups, + uint16_t* _q_group_map) + : device(_device), height(_height), width(_width), groups(_groups) { + cudaSetDevice(device); + + failed = false; + + cuda_q_weight = _q_weight; + cuda_q_perm = _q_perm; + cuda_q_invperm = _q_invperm; + cuda_q_scale = _q_scale; + cuda_q_scale_max = _q_scale_max; + cuda_q_groups = _q_groups; + cuda_q_group_map = _q_group_map; + + // Create group map + + rows_8 = 0; + rows_6 = 0; + rows_5 = 0; + rows_4 = 0; + rows_3 = 0; + rows_2 = 0; + + { + uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t)); + cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), + cudaMemcpyDeviceToHost); + + int row = 0; + for (int i = 0; i < groups; i++) { + int bits = cpu_q_groups[i * 2]; + + int rows; + if (i < groups - 1) { + int qrows = cpu_q_groups[i * 2 + 3] - cpu_q_groups[i * 2 + 1]; + rows = qrows * 32 / bits; + } else + rows = height - row; + + if (bits == 8) rows_8 += rows; + if (bits == 6) rows_6 += rows; + if (bits == 5) rows_5 += rows; + if (bits == 4) rows_4 += rows; + if (bits == 3) rows_3 += rows; + if (bits == 2) rows_2 += rows; + row += rows; + } + + free(cpu_q_groups); + + rows_6 += rows_8; + rows_5 += rows_6; + rows_4 += rows_5; + rows_3 += rows_4; + rows_2 += rows_3; + } + + // Shuffle quantized data + + dim3 
blockDim, gridDim; + blockDim.x = THREADS_X; + blockDim.y = 1; + gridDim.x = DIVIDE(width, THREADS_X); + gridDim.y = 1; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + shuffle_kernel<<>>(cuda_q_weight, height, width, + rows_8, rows_6, rows_5, + rows_4, rows_3, rows_2); +} + +QMatrix::~QMatrix() {} + +// Reconstruct b[k,n] + +__global__ void reconstruct_kernel(const uint32_t* __restrict__ b_q_weight, + const uint16_t* __restrict__ b_q_perm, + const uint32_t* __restrict__ b_q_scale, + const half* __restrict__ b_q_scale_max, + const uint16_t* __restrict__ b_q_group_map, + const int size_k, const int size_n, + // const int groupsize, + const int groups, half* __restrict__ b, + const int rows_8, const int rows_6, + const int rows_5, const int rows_4, + const int rows_3, const int rows_2) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n); + + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x; + + // Preload remapping table + + int t = threadIdx.x; + __shared__ uint16_t perm[BLOCK_KN_SIZE]; + if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t]; + + // Column + + int n = offset_n + t; + if (n >= size_n) return; + + // Find initial group + + // int group = offset_k / groupsize; + int group = b_q_group_map[offset_k * 2]; + + int pre_rows_8 = min(rows_8, offset_k); + int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0; + int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0; + int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0; + int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0; + int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0; + int qk = 0; + qk += pre_rows_8 / 32 * 8; + qk += pre_rows_6 / 32 * 6; + qk += pre_rows_5 / 32 * 5; + qk += pre_rows_4 / 32 * 4; + qk += pre_rows_3 / 32 * 3; + qk += pre_rows_2 / 32 * 2; + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + + half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); + half2 qs_h2 = __halves2half2(qs_h, qs_h); + int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1]; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + int k = offset_k; + int lk = 0; + + __syncthreads(); + + while (k < rows_8 && k < end_k) { + if (k == nextgroup) { + group++; + qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); + nextgroup += b_q_group_map[k * 2 + 1]; + qs_h2 = __halves2half2(qs_h, qs_h); + } + for (int p = 0; p < 4; p++) { + half2 dq[4]; + uint32_t q_0 = *b_ptr; + b_ptr += size_n; + uint32_t q_1 = *b_ptr; + b_ptr += size_n; + dequant_8bit_8(q_0, q_1, dq, size_n); + for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2); + half* dqh = (half*)dq; + for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]); + } + k += 32; + } + + while (k < rows_6 && k < end_k) { + if (k == nextgroup) { + group++; + qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); + nextgroup += b_q_group_map[k * 2 + 1]; + qs_h2 = __halves2half2(qs_h, qs_h); + } + for (int p = 0; p < 2; p++) { + half2 dq[8]; + uint32_t q_0 = *b_ptr; + b_ptr += size_n; + uint32_t q_1 = *b_ptr; + b_ptr += size_n; + uint32_t q_2 = *b_ptr; + b_ptr += size_n; + dequant_6bit_16(q_0, q_1, q_2, dq, size_n); + for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2); + half* dqh = (half*)dq; + for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]); + } + k += 32; + } + + while (k < rows_5 && k < end_k) { + if (k == nextgroup) { + 
group++; + qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); + nextgroup += b_q_group_map[k * 2 + 1]; + qs_h2 = __halves2half2(qs_h, qs_h); + } + for (int p = 0; p < 1; p++) { + half2 dq[16]; + uint32_t q_0 = *b_ptr; + b_ptr += size_n; + uint32_t q_1 = *b_ptr; + b_ptr += size_n; + uint32_t q_2 = *b_ptr; + b_ptr += size_n; + uint32_t q_3 = *b_ptr; + b_ptr += size_n; + uint32_t q_4 = *b_ptr; + b_ptr += size_n; + dequant_5bit_32(q_0, q_1, q_2, q_3, q_4, dq, size_n); + for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2); + half* dqh = (half*)dq; + for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]); + } + k += 32; + } + + while (k < rows_4 && k < end_k) { + if (k == nextgroup) { + group++; + qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); + nextgroup += b_q_group_map[k * 2 + 1]; + qs_h2 = __halves2half2(qs_h, qs_h); + } + for (int p = 0; p < 4; p++) { + half2 dq[4]; + uint32_t q_0 = *b_ptr; + b_ptr += size_n; + dequant_4bit_8(q_0, dq, size_n); + for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2); + half* dqh = (half*)dq; + for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]); + } + k += 32; + } + + while (k < rows_3 && k < end_k) { + if (k == nextgroup) { + group++; + qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); + nextgroup += b_q_group_map[k * 2 + 1]; + qs_h2 = __halves2half2(qs_h, qs_h); + } + for (int p = 0; p < 1; p++) { + half2 dq[16]; + uint32_t q_0 = *b_ptr; + b_ptr += size_n; + uint32_t q_1 = *b_ptr; + b_ptr += size_n; + uint32_t q_2 = *b_ptr; + b_ptr += size_n; + dequant_3bit_32(q_0, q_1, q_2, dq, size_n); + for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2); + half* dqh = (half*)dq; + for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]); + } + k += 32; + } + + while (k < rows_2 && k < end_k) { + if (k == nextgroup) { + group++; + qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); + nextgroup += b_q_group_map[k * 2 + 1]; + qs_h2 = __halves2half2(qs_h, qs_h); + } + for (int p = 0; p < 1; p++) { + half2 dq[8]; + uint32_t q_0 = *b_ptr; + b_ptr += size_n; + dequant_2bit_16(q_0, dq, size_n); + for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2); + half* dqh = (half*)dq; + for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]); + } + k += 16; + } +} + +void QMatrix::reconstruct(half* out) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + gridDim.y = DIVIDE(height, BLOCK_KN_SIZE); + + { + gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + reconstruct_kernel<<>>( + cuda_q_weight, cuda_q_perm, cuda_q_scale, cuda_q_scale_max, + cuda_q_group_map, height, width, + // groupsize, + groups, out, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2); + } +} + +} // namespace exl2 +} // namespace vllm \ No newline at end of file diff --git a/csrc/quantization/exllamav2/q_matrix.cuh b/csrc/quantization/exllamav2/q_matrix.cuh new file mode 100644 index 0000000000000..3734e51fea97c --- /dev/null +++ b/csrc/quantization/exllamav2/q_matrix.cuh @@ -0,0 +1,85 @@ +/* + * Adapted from https://github.com/turboderp/exllamav2 + * Copyright (c) 2024 turboderp + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to 
permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef _q_matrix_cuh +#define _q_matrix_cuh + +#include +#include +#include +#include + +namespace vllm { +namespace exl2 { + +#define MAX_SUPERGROUPS 16 + +class QMatrix { + public: + int device; + bool is_gptq; + + int height; + int width; + int groups; + int gptq_groupsize; + + int rows_8; + int rows_6; + int rows_5; + int rows_4; + int rows_3; + int rows_2; + + uint32_t* cuda_q_weight = NULL; + uint16_t* cuda_q_perm = NULL; + uint16_t* cuda_q_invperm = NULL; + uint32_t* cuda_q_scale = NULL; + half* cuda_q_scale_max = NULL; + uint16_t* cuda_q_groups = NULL; + uint16_t* cuda_q_group_map = NULL; + uint32_t* cuda_gptq_qzeros = NULL; + half* cuda_gptq_scales = NULL; + + half* temp_dq; + + bool failed; + + QMatrix(const int _device, const int _height, const int _width, + const int _groups, + + uint32_t* _q_weight, uint16_t* _q_perm, uint16_t* _q_invperm, + uint32_t* _q_scale, half* _q_scale_max, uint16_t* _q_groups, + uint16_t* _q_group_map); + + ~QMatrix(); + + void reconstruct(half* out); + bool make_sequential(const uint32_t* cpu_g_idx); + + private: +}; + +} // namespace exl2 +} // namespace vllm + +#endif \ No newline at end of file diff --git a/csrc/quantization/exllamav2/quant/qdq_2.cuh b/csrc/quantization/exllamav2/quant/qdq_2.cuh new file mode 100644 index 0000000000000..e506810727754 --- /dev/null +++ b/csrc/quantization/exllamav2/quant/qdq_2.cuh @@ -0,0 +1,91 @@ +/* + * Adapted from https://github.com/turboderp/exllamav2 + * Copyright (c) 2024 turboderp + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#ifndef _qdq_2_cuh +#define _qdq_2_cuh + +#include "qdq_util.cuh" + +namespace vllm { +namespace exl2 { +// Permutation: +// +// ffddbb99 77553311 eeccaa88 66442200 + +__forceinline__ __device__ void shuffle_2bit_16(uint32_t* q, int stride) { + uint32_t qa = q[0]; + uint32_t qb = 0; + +#pragma unroll + for (int i = 0; i < 8; i++) { + uint32_t qa0 = qa & 0x03; + uint32_t qa1 = (qa & 0x0c) >> 2; + qa >>= 4; + qb |= (qa1 << (i * 2 + 16)); + qb |= (qa0 << (i * 2)); + } + q[0] = qb; +} + +__forceinline__ __device__ void dequant_2bit_16(const uint32_t q_0, + half2 (&dq)[8], int stride) { + const uint32_t c0 = 0x64006400; + const half y4_ = __float2half_rn(1.0f / 4.0f); + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y4 = __halves2half2(y4_, y4_); + const half2 y16 = __halves2half2(y16_, y16_); + const half2 y64 = __halves2half2(y64_, y64_); + const half z1_ = __float2half_rn(-1024.0f - 2.0f); + const half z4_ = __float2half_rn(-1024.0f / 4.0f - 2.0f); + const half z16_ = __float2half_rn(-1024.0f / 16.0f - 2.0f); + const half z64_ = __float2half_rn(-1024.0f / 64.0f - 2.0f); + const half2 z1 = __halves2half2(z1_, z1_); + const half2 z4 = __halves2half2(z4_, z4_); + const half2 z16 = __halves2half2(z16_, z16_); + const half2 z64 = __halves2half2(z64_, z64_); + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024 + half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024 + half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024 + qa >>= 8; + half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024 + half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024 + half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024 + half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024 + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y4, z4); + dq[2] = __hfma2(q2.as_half2, y16, z16); + dq[3] = __hfma2(q3.as_half2, y64, z64); + dq[4] = __hadd2(q4.as_half2, z1); + dq[5] = __hfma2(q5.as_half2, y4, z4); + dq[6] = __hfma2(q6.as_half2, y16, z16); + dq[7] = __hfma2(q7.as_half2, y64, z64); +} + +} // namespace exl2 +} // namespace vllm + +#endif \ No newline at end of file diff --git a/csrc/quantization/exllamav2/quant/qdq_3.cuh b/csrc/quantization/exllamav2/quant/qdq_3.cuh new file mode 100644 index 0000000000000..334b7357fc7fb --- /dev/null +++ b/csrc/quantization/exllamav2/quant/qdq_3.cuh @@ -0,0 +1,171 @@ +/* + * Adapted from https://github.com/turboderp/exllamav2 + * Copyright (c) 2024 turboderp + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef _qdq_3_cuh +#define _qdq_3_cuh + +#include "qdq_util.cuh" + +namespace vllm { +namespace exl2 { + +// Permutation: +// +// v9997775 55333111 u8886664 44222000 (u, v lsb) +// vjjjhhhf ffdddbbb uiiiggge eecccaaa +// vtttrrrp ppnnnlll usssqqqo oommmkkk + +__forceinline__ __device__ void shuffle_3bit_32(uint32_t* q, int stride) { + uint32_t qa = q[0 * stride]; + uint32_t qb = q[1 * stride]; + uint32_t qc = q[2 * stride]; + + // qa: aa999888 77766655 54443332 22111000 + // qb: lkkkjjji iihhhggg fffeeedd dcccbbba + // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll + + uint32_t qd = qc >> 26; + qc <<= 4; + qc |= qb >> 28; + qb <<= 2; + qb |= qa >> 30; + + // qa: ..999888 77766655 54443332 22111000 + // qb: ..jjjiii hhhgggff feeedddc ccbbbaaa + // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk + // qd: vvvuuu + + uint32_t za = 0; + uint32_t zb = 0; + uint32_t zc = 0; + + for (int i = 0; i < 5; i++) { + uint32_t t0 = qa & 0x07; + uint32_t t1 = (qa & 0x38) >> 3; + qa >>= 6; + za |= (t0 << (i * 3)); + za |= (t1 << (i * 3 + 16)); + } + for (int i = 0; i < 5; i++) { + uint32_t t0 = qb & 0x07; + uint32_t t1 = (qb & 0x38) >> 3; + qb >>= 6; + zb |= (t0 << (i * 3)); + zb |= (t1 << (i * 3 + 16)); + } + for (int i = 0; i < 5; i++) { + uint32_t t0 = qc & 0x07; + uint32_t t1 = (qc & 0x38) >> 3; + qc >>= 6; + zc |= (t0 << (i * 3)); + zc |= (t1 << (i * 3 + 16)); + } + + // za: 9997775 55333111 8886664 44222000 + // zb: jjjhhhf ffdddbbb iiiggge eecccaaa + // zc: tttrrrp ppnnnlll sssqqqo oommmkkk + // qd: vvvuuu + + za |= ((qd & 0x01) >> 0) << 15; + zb |= ((qd & 0x02) >> 1) << 15; + zc |= ((qd & 0x04) >> 2) << 15; + za |= ((qd & 0x08) >> 3) << 31; + zb |= ((qd & 0x10) >> 4) << 31; + zc |= ((qd & 0x20) >> 5) << 31; + + // za: v9997775 55333111 u8886664 44222000 (u, v lsb) + // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa + // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk + + q[0 * stride] = za; + q[1 * stride] = zb; + q[2 * stride] = zc; +} + +__forceinline__ __device__ void dequant_3bit_32(const uint32_t q_0, + const uint32_t q_1, + const uint32_t q_2, + half2 (&dq)[16], int stride) { + const uint32_t c0 = 0x64006400; + const half y8_ = __float2half_rn(1.0f / 8.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y8 = __halves2half2(y8_, y8_); + const half2 y64 = __halves2half2(y64_, y64_); + const half z1_ = __float2half_rn(-1024.0f - 4.0f); + const half z8_ = __float2half_rn(-1024.0f / 8.0f - 4.0f); + const half z64_ = __float2half_rn(-1024.0f / 64.0f - 4.0f); + const half2 z1 = __halves2half2(z1_, z1_); + const half2 z8 = __halves2half2(z8_, z8_); + const half2 z64 = __halves2half2(z64_, z64_); + + uint32_t qa = q_0; + uint32_t qb = q_1; + uint32_t qc = q_2; + + half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024 + qa >>= 6; + half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024 + half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024 + qa >>= 9; + qa &= 0x00010001; + half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024 + half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024 + qb >>= 6; + 
half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024 + half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024 + half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024 + qb >>= 8; + qb &= 0x00020002; + half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024 + half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024 + qc >>= 6; + half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024 + half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024 + half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024 + qc >>= 7; + qc &= 0x00040004; + half2_uint32 q15((qa | qb | qc) | c0); + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y8, z8); + dq[2] = __hadd2(q2.as_half2, z1); + dq[3] = __hfma2(q3.as_half2, y8, z8); + dq[4] = __hfma2(q4.as_half2, y64, z64); + dq[5] = __hadd2(q5.as_half2, z1); + dq[6] = __hfma2(q6.as_half2, y8, z8); + dq[7] = __hadd2(q7.as_half2, z1); + dq[8] = __hfma2(q8.as_half2, y8, z8); + dq[9] = __hfma2(q9.as_half2, y64, z64); + dq[10] = __hadd2(q10.as_half2, z1); + dq[11] = __hfma2(q11.as_half2, y8, z8); + dq[12] = __hadd2(q12.as_half2, z1); + dq[13] = __hfma2(q13.as_half2, y8, z8); + dq[14] = __hfma2(q14.as_half2, y64, z64); + dq[15] = __hadd2(q15.as_half2, z1); +} + +} // namespace exl2 +} // namespace vllm + +#endif \ No newline at end of file diff --git a/csrc/quantization/exllamav2/quant/qdq_4.cuh b/csrc/quantization/exllamav2/quant/qdq_4.cuh new file mode 100644 index 0000000000000..8c16000df581d --- /dev/null +++ b/csrc/quantization/exllamav2/quant/qdq_4.cuh @@ -0,0 +1,145 @@ +/* + * Adapted from https://github.com/turboderp/exllamav2 + * Copyright (c) 2024 turboderp + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#ifndef _qdq_4_cuh +#define _qdq_4_cuh + +#include "qdq_util.cuh" + +namespace vllm { +namespace exl2 { + +// Permutation: +// +// 77775555 33331111 66664444 22220000 + +__forceinline__ __device__ void shuffle_4bit_8(uint32_t* q, int stride) { + uint32_t qa = q[0]; + uint32_t qb = 0; + +#pragma unroll + for (int i = 0; i < 4; i++) { + uint32_t qa0 = qa & 0x0f; + uint32_t qa1 = (qa & 0xf0) >> 4; + qa >>= 8; + qb |= (qa1 << (i * 4 + 16)); + qb |= (qa0 << (i * 4)); + } + q[0] = qb; +} + +__forceinline__ __device__ void dequant_4bit_8(const uint32_t q_0, + half2 (&dq)[4], int stride) { + const uint32_t c0 = 0x64006400; + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half2 y16 = __halves2half2(y16_, y16_); + const half z1_ = __float2half_rn(-1024.0f - 8.0f); + const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f); + const half2 z1 = __halves2half2(z1_, z1_); + const half2 z16 = __halves2half2(z16_, z16_); + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024 + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024 + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y16, z16); + dq[2] = __hadd2(q2.as_half2, z1); + dq[3] = __hfma2(q3.as_half2, y16, z16); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale( + const uint32_t zero, const half scale, half2 (&z1z16)[2], + half2 (&y1y16)[2]) { + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + + half2 scale2 = __half2half2(scale); + + z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half)); + z1z16[1] = __hmul2(scale2, __half2half2(z16)); + + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); + + y1y16[0] = __hmul2(scale2, __half2half2(y1)); + y1y16[1] = __hmul2(scale2, __half2half2(y16)); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero(const uint32_t zero, + half2 (&z1z16)[2], + half2 (&y1y16)[2]) { + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + + z1z16[0] = __half2half2(z1.as_half); + z1z16[1] = __half2half2(z16); + + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); + + y1y16[0] = __half2half2(y1); + y1y16[1] = __half2half2(y16); +} + +__forceinline__ __device__ void dequant_4bit_8_gptq(const uint32_t q_0, + half2 (&dq)[4], + half2 (&z1z16)[2], + half2 (&y1y16)[2], + int stride, bool scaled) { + const uint32_t c0 = 0x64006400; + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | + c0); // half2( q[0] + 1024, q[1] + 1024 ) + half2_uint32 q1((qa & 0x00f000f0) | + c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 ) + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | + c0); // half2( q[4] + 1024, q[5] + 1024 ) + half2_uint32 q3((qa & 0x00f000f0) | + c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 ) + + if (scaled) { + dq[0] = __hfma2(q0.as_half2, y1y16[0], + z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s) + dq[1] = __hfma2(q1.as_half2, y1y16[1], + z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s) + dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]); + dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); + } else { + dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z ) + dq[1] = __hfma2(q1.as_half2, y1y16[1], 
+ z1z16[1]); // half2( q[2] - z, q[3] - z ) + dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z ) + dq[3] = __hfma2(q3.as_half2, y1y16[1], + z1z16[1]); // half2( q[6] - z, q[7] - z ) + } +} + +} // namespace exl2 +} // namespace vllm + +#endif \ No newline at end of file diff --git a/csrc/quantization/exllamav2/quant/qdq_5.cuh b/csrc/quantization/exllamav2/quant/qdq_5.cuh new file mode 100644 index 0000000000000..be62fd238c23b --- /dev/null +++ b/csrc/quantization/exllamav2/quant/qdq_5.cuh @@ -0,0 +1,212 @@ +/* + * Adapted from https://github.com/turboderp/exllamav2 + * Copyright (c) 2024 turboderp + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef _qdq_5_cuh +#define _qdq_5_cuh + +#include "qdq_util.cuh" + +namespace vllm { +namespace exl2 { + +// Permutation: +// +// v5555533 33311111 u4444422 22200000 (u, v lsb) +// vbbbbb99 99977777 uaaaaa88 88866666 +// vhhhhhff fffddddd ugggggee eeeccccc +// vnnnnnll llljjjjj ummmmmkk kkkiiiii +// vtttttrr rrrppppp usssssqq qqqooooo + +__forceinline__ __device__ void shuffle_5bit_32(uint32_t* q, int stride) { + uint32_t qa = q[0 * stride]; + uint32_t qb = q[1 * stride]; + uint32_t qc = q[2 * stride]; + uint32_t qd = q[3 * stride]; + uint32_t qe = q[4 * stride]; + + // qa: 66555554 44443333 32222211 11100000 + // qb: ccccbbbb baaaaa99 99988888 77777666 + // qc: jiiiiihh hhhggggg fffffeee eedddddc + // qd: pppooooo nnnnnmmm mmlllllk kkkkjjjj + // qe: vvvvvuuu uuttttts ssssrrrr rqqqqqpp + + uint32_t qf = qe >> 22; + qe <<= 8; + qe |= qd >> 24; + qd <<= 6; + qd |= qc >> 26; + qc <<= 4; + qc |= qb >> 28; + qb <<= 2; + qb |= qa >> 30; + + // qa: 555554 44443333 32222211 11100000 + // qb: bbbbba aaaa9999 98888877 77766666 + // qc: hhhhhg ggggffff feeeeedd dddccccc + // qd: nnnnnm mmmmllll lkkkkkjj jjjiiiii + // qe: ttttts ssssrrrr rqqqqqpp pppooooo + // qf: vv vvvuuuuu + + uint32_t za = 0; + uint32_t zb = 0; + uint32_t zc = 0; + uint32_t zd = 0; + uint32_t ze = 0; + + for (int i = 0; i < 3; i++) { + uint32_t t0 = qa & 0x1f; + uint32_t t1 = (qa & 0x3e0) >> 5; + qa >>= 10; + za |= (t0 << (i * 5)); + za |= (t1 << (i * 5 + 16)); + } + for (int i = 0; i < 3; i++) { + uint32_t t0 = qb & 0x1f; + uint32_t t1 = (qb & 0x3e0) >> 5; + qb >>= 10; + zb |= (t0 << (i * 5)); + zb |= (t1 << (i * 5 + 16)); + } + for (int i = 0; i < 3; i++) { + uint32_t t0 = qc & 0x1f; + uint32_t t1 = (qc & 0x3e0) >> 5; + qc >>= 10; + zc |= (t0 << (i * 5)); + zc |= (t1 << (i * 5 + 16)); + } + for (int i = 0; i < 3; i++) { + uint32_t t0 = 
qd & 0x1f; + uint32_t t1 = (qd & 0x3e0) >> 5; + qd >>= 10; + zd |= (t0 << (i * 5)); + zd |= (t1 << (i * 5 + 16)); + } + for (int i = 0; i < 3; i++) { + uint32_t t0 = qe & 0x1f; + uint32_t t1 = (qe & 0x3e0) >> 5; + qe >>= 10; + ze |= (t0 << (i * 5)); + ze |= (t1 << (i * 5 + 16)); + } + + // za: 5555533 33311111 4444422 22200000 + // zb: bbbbb99 99977777 aaaaa88 88866666 + // zc: hhhhhff fffddddd gggggee eeeccccc + // zd: nnnnnll llljjjjj mmmmmkk kkkiiiii + // ze: tttttrr rrrppppp sssssqq qqqooooo + // qf: vv vvvuuuuu + + za |= ((qf & 0x001) >> 0) << 15; + zb |= ((qf & 0x002) >> 1) << 15; + zc |= ((qf & 0x004) >> 2) << 15; + zd |= ((qf & 0x008) >> 3) << 15; + ze |= ((qf & 0x010) >> 4) << 15; + za |= ((qf & 0x020) >> 5) << 31; + zb |= ((qf & 0x040) >> 6) << 31; + zc |= ((qf & 0x080) >> 7) << 31; + zd |= ((qf & 0x100) >> 8) << 31; + ze |= ((qf & 0x200) >> 9) << 31; + + // za: v5555533 33311111 u4444422 22200000 (u, v lsb) + // zb: vbbbbb99 99977777 uaaaaa88 88866666 + // zc: vhhhhhff fffddddd ugggggee eeeccccc + // zd: vnnnnnll llljjjjj ummmmmkk kkkiiiii + // ze: vtttttrr rrrppppp usssssqq qqqooooo + + q[0 * stride] = za; + q[1 * stride] = zb; + q[2 * stride] = zc; + q[3 * stride] = zd; + q[4 * stride] = ze; +} + +__forceinline__ __device__ void dequant_5bit_32( + const uint32_t q_0, const uint32_t q_1, const uint32_t q_2, + const uint32_t q_3, const uint32_t q_4, half2 (&dq)[16], int stride) { + const uint32_t c0 = 0x64006400; + const half y32_ = __float2half_rn(1.0f / 32.0f); + const half2 y32 = __halves2half2(y32_, y32_); + const half z1_ = __float2half_rn(-1024.0f - 16.0f); + const half z32_ = __float2half_rn(-1024.0f / 32.0f - 16.0f); + const half2 z1 = __halves2half2(z1_, z1_); + const half2 z32 = __halves2half2(z32_, z32_); + + uint32_t qa = q_0; + uint32_t qb = q_1; + uint32_t qc = q_2; + uint32_t qd = q_3; + uint32_t qe = q_4; + + half2_uint32 q0((qa & 0x001f001f) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x03e003e0) | c0); // half2(q[ 2], q[ 3]) * 32 + 1024 + qa >>= 10; + half2_uint32 q2((qa & 0x001f001f) | c0); // half2(q[ 4], q[ 5]) + 1024 + qa >>= 5; + qa &= 0x00010001; + half2_uint32 q3((qb & 0x001f001f) | c0); // half2(q[ 6], q[ 7]) + 1024 + half2_uint32 q4((qb & 0x03e003e0) | c0); // half2(q[ 8], q[ 9]) * 32 + 1024 + qb >>= 10; + half2_uint32 q5((qb & 0x001f001f) | c0); // half2(q[10], q[11]) + 1024 + qb >>= 4; + qb &= 0x00020002; + half2_uint32 q6((qc & 0x001f001f) | c0); // half2(q[12], q[13]) + 1024 + half2_uint32 q7((qc & 0x03e003e0) | c0); // half2(q[14], q[15]) * 32 + 1024 + qc >>= 10; + half2_uint32 q8((qc & 0x001f001f) | c0); // half2(q[16], q[17]) + 1024 + qc >>= 3; + qc &= 0x00040004; + half2_uint32 q9((qd & 0x001f001f) | c0); // half2(q[18], q[19]) + 1024 + half2_uint32 q10((qd & 0x03e003e0) | c0); // half2(q[20], q[21]) * 32 + 1024 + qd >>= 10; + half2_uint32 q11((qd & 0x001f001f) | c0); // half2(q[22], q[23]) + 1024 + qd >>= 2; + qd &= 0x00080008; + half2_uint32 q12((qe & 0x001f001f) | c0); // half2(q[24], q[25]) + 1024 + half2_uint32 q13((qe & 0x03e003e0) | c0); // half2(q[26], q[27]) * 32 + 1024 + qe >>= 10; + half2_uint32 q14((qe & 0x001f001f) | c0); // half2(q[28], q[29]) + 1024 + qe >>= 1; + qe &= 0x00100010; + half2_uint32 q15((qa | qb | qc | qd | qe) | c0); + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y32, z32); + dq[2] = __hadd2(q2.as_half2, z1); + dq[3] = __hadd2(q3.as_half2, z1); + dq[4] = __hfma2(q4.as_half2, y32, z32); + dq[5] = __hadd2(q5.as_half2, z1); + dq[6] = __hadd2(q6.as_half2, z1); + dq[7] = 
__hfma2(q7.as_half2, y32, z32); + dq[8] = __hadd2(q8.as_half2, z1); + dq[9] = __hadd2(q9.as_half2, z1); + dq[10] = __hfma2(q10.as_half2, y32, z32); + dq[11] = __hadd2(q11.as_half2, z1); + dq[12] = __hadd2(q12.as_half2, z1); + dq[13] = __hfma2(q13.as_half2, y32, z32); + dq[14] = __hadd2(q14.as_half2, z1); + dq[15] = __hadd2(q15.as_half2, z1); +} + +} // namespace exl2 +} // namespace vllm + +#endif \ No newline at end of file diff --git a/csrc/quantization/exllamav2/quant/qdq_6.cuh b/csrc/quantization/exllamav2/quant/qdq_6.cuh new file mode 100644 index 0000000000000..7d02d7a610551 --- /dev/null +++ b/csrc/quantization/exllamav2/quant/qdq_6.cuh @@ -0,0 +1,52 @@ +/* + * Adapted from https://github.com/turboderp/exllamav2 + * Copyright (c) 2024 turboderp + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */
+#ifndef _qdq_6_cuh
+#define _qdq_6_cuh
+
+#include "qdq_util.cuh"
+
+namespace vllm {
+namespace exl2 {
+
+__forceinline__ __device__ void shuffle_6bit_16(uint32_t* q, int stride) {}
+
+__forceinline__ __device__ void dequant_6bit_16(const uint32_t q_0,
+                                                const uint32_t q_1,
+                                                const uint32_t q_2,
+                                                half2 (&dq)[8], int stride) {
+  half dqh[16];
+  for (int i = 0; i < 5; i++) dqh[i] = dq_ns(exb(q_0, i * 6, 0x3f), 32);
+  dqh[5] = dq_ns(exb(q_1, q_0, 30, 0x3f), 32);
+  for (int i = 0; i < 4; i++) dqh[6 + i] = dq_ns(exb(q_1, i * 6 + 4, 0x3f), 32);
+  dqh[10] = dq_ns(exb(q_2, q_1, 28, 0x3f), 32);
+  for (int i = 0; i < 5; i++)
+    dqh[11 + i] = dq_ns(exb(q_2, i * 6 + 2, 0x3f), 32);
+
+  for (int i = 0; i < 8; i++)
+    dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
+}
+
+}  // namespace exl2
+}  // namespace vllm
+
+#endif
\ No newline at end of file
diff --git a/csrc/quantization/exllamav2/quant/qdq_8.cuh b/csrc/quantization/exllamav2/quant/qdq_8.cuh
new file mode 100644
index 0000000000000..cc2642d7e24f2
--- /dev/null
+++ b/csrc/quantization/exllamav2/quant/qdq_8.cuh
@@ -0,0 +1,47 @@
+/*
+ * Adapted from https://github.com/turboderp/exllamav2
+ * Copyright (c) 2024 turboderp
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef _qdq_8_cuh
+#define _qdq_8_cuh
+
+#include "qdq_util.cuh"
+
+namespace vllm {
+namespace exl2 {
+
+__forceinline__ __device__ void shuffle_8bit_4(uint32_t* q, int stride) {}
+
+__forceinline__ __device__ void dequant_8bit_8(const uint32_t q_0,
+                                               const uint32_t q_1,
+                                               half2 (&dq)[4], int stride) {
+  half dqh[8];
+  for (int i = 0; i < 4; i++) dqh[i] = dq_ns(exb(q_0, i * 8, 0xff), 128);
+  for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), 128);
+
+  for (int i = 0; i < 4; i++)
+    dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
+}
+
+}  // namespace exl2
+}  // namespace vllm
+
+#endif
\ No newline at end of file
diff --git a/csrc/quantization/exllamav2/quant/qdq_util.cuh b/csrc/quantization/exllamav2/quant/qdq_util.cuh
new file mode 100644
index 0000000000000..461d015100c88
--- /dev/null
+++ b/csrc/quantization/exllamav2/quant/qdq_util.cuh
@@ -0,0 +1,76 @@
+/*
+ * Adapted from https://github.com/turboderp/exllamav2
+ * Copyright (c) 2024 turboderp
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */ +#ifndef _qdq_util_cuh +#define _qdq_util_cuh + +namespace vllm { +namespace exl2 { + +union half2_uint32 { + uint32_t as_uint32; + half2 as_half2; + __device__ half2_uint32(uint32_t val) : as_uint32(val) {} + __device__ half2_uint32(half2 val) : as_half2(val) {} + __device__ half2_uint32() : as_uint32(0) {} +}; + +union half_uint16 { + uint16_t as_uint16; + half as_half; + __device__ half_uint16(uint16_t val) : as_uint16(val) {} + __device__ half_uint16(half val) : as_half(val) {} + __device__ half_uint16() : as_uint16(0) {} +}; + +// Max_scale premultiplied by 1/256 + +__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) { + int qs_i = qs + 1; + half qs_h = __int2half_rn(qs_i * qs_i); + qs_h = __hmul(qs_h, max_scale); + return qs_h; +} + +__forceinline__ __device__ half dq(const int q, const int qzero, + const half scale) { + return __hmul(__int2half_rn(q - qzero), scale); +} + +__forceinline__ __device__ half dq_ns(const int q, const int qzero) { + // return __hsub(__int2half_rn(q), __int2half_rn(qzero)); + return __int2half_rn(q - qzero); +} + +__forceinline__ __device__ int exb(const uint32_t q, const int shift, + const int mask) { + return (int)((q >> shift) & mask); +} + +__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, + const int shift, const int mask) { + return (int)(__funnelshift_rc(q0, q1, shift) & mask); +} + +} // namespace exl2 +} // namespace vllm +#endif \ No newline at end of file diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 956258c1001d3..552a1ac4995a4 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -343,6 +343,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor a) -> bool"); ops.impl("cutlass_sparse_compress_entry", &cutlass_sparse_compress_entry); + // ExLlamaV2 GEMM kernels + ops.def( + "make_q_matrix(Tensor q_weight, Tensor q_perm, Tensor q_invperm," + " Tensor q_scale, Tensor q_scale_max, Tensor q_groups," + " Tensor q_group_map) -> uintptr_t"); + ops.impl("make_q_matrix", torch::kCUDA, &make_q_matrix); + + ops.def("exl2_gemm(Tensor a, uintptr_t b) -> Tensor"); + ops.impl("exl2_gemm(Tensor a, uintptr_t b) -> Tensor", torch::kCUDA, + &exl2_gemm); + // Mamba selective scan kernel ops.def( "selective_scan_fwd(Tensor! u, Tensor! 
delta," From 491671b8f600b6fc77df6e0a9d038308c1265fa3 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Fri, 20 Dec 2024 00:24:51 +0000 Subject: [PATCH 2/4] fix compilation --- csrc/ops.h | 10 +++++----- csrc/quantization/exllamav2/q_gemm.cu | 24 +++++++++++++++++------- csrc/torch_bindings.cpp | 10 +++++----- 3 files changed, 27 insertions(+), 17 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index f1f9914c5686e..09942311ba1a2 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -200,12 +200,12 @@ void dynamic_per_token_scaled_fp8_quant( torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale, c10::optional const& scale_ub); -uintptr_t make_q_matrix(torch::Tensor q_weight, torch::Tensor q_perm, - torch::Tensor q_invperm, torch::Tensor q_scale, - torch::Tensor q_scale_max, torch::Tensor q_groups, - torch::Tensor q_group_map); +int64_t make_q_matrix(torch::Tensor q_weight, torch::Tensor q_perm, + torch::Tensor q_invperm, torch::Tensor q_scale, + torch::Tensor q_scale_max, torch::Tensor q_groups, + torch::Tensor q_group_map); -torch::Tensor exl2_gemm(torch::Tensor a, uintptr_t b); +torch::Tensor exl2_gemm(torch::Tensor a, int64_t b); void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A, const torch::Tensor& B, diff --git a/csrc/quantization/exllamav2/q_gemm.cu b/csrc/quantization/exllamav2/q_gemm.cu index b93e9ed9e9adc..f174f0d1e1534 100644 --- a/csrc/quantization/exllamav2/q_gemm.cu +++ b/csrc/quantization/exllamav2/q_gemm.cu @@ -119,9 +119,13 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a, } // namespace exl2 } // namespace vllm -torch::Tensor exl2_gemm(torch::Tensor a, uintptr_t b) { +torch::Tensor exl2_gemm(torch::Tensor a, int64_t b) { const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - vllm::exl2::QMatrix* qm = reinterpret_cast(b); + if (b < 0) { + throw std::runtime_error("Invalid pointer value passed as int64_t"); + } + vllm::exl2::QMatrix* qm = + reinterpret_cast(static_cast(b)); auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); at::Tensor c = torch::empty({a.size(0), qm->width}, options); @@ -140,10 +144,10 @@ torch::Tensor exl2_gemm(torch::Tensor a, uintptr_t b) { return c; } -uintptr_t make_q_matrix(torch::Tensor q_weight, torch::Tensor q_perm, - torch::Tensor q_invperm, torch::Tensor q_scale, - torch::Tensor q_scale_max, torch::Tensor q_groups, - torch::Tensor q_group_map) { +int64_t make_q_matrix(torch::Tensor q_weight, torch::Tensor q_perm, + torch::Tensor q_invperm, torch::Tensor q_scale, + torch::Tensor q_scale_max, torch::Tensor q_groups, + torch::Tensor q_group_map) { const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight)); int device = q_weight.device().index(); int width = q_weight.size(1); @@ -155,5 +159,11 @@ uintptr_t make_q_matrix(torch::Tensor q_weight, torch::Tensor q_perm, (uint16_t*)q_perm.data_ptr(), (uint16_t*)q_invperm.data_ptr(), (uint32_t*)q_scale.data_ptr(), (half*)q_scale_max.data_ptr(), (uint16_t*)q_groups.data_ptr(), (uint16_t*)q_group_map.data_ptr()); - return reinterpret_cast(m); + + uintptr_t ptr_val = reinterpret_cast(m); + if (ptr_val > static_cast(std::numeric_limits::max())) { + delete m; + throw std::runtime_error("Pointer value too large for int64_t"); + } + return static_cast(ptr_val); } \ No newline at end of file diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 552a1ac4995a4..5d4a7efbe7f21 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -345,13 +345,13 @@ 
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // ExLlamaV2 GEMM kernels ops.def( - "make_q_matrix(Tensor q_weight, Tensor q_perm, Tensor q_invperm," + "exl2_make_q_matrix(Tensor q_weight, Tensor q_perm, Tensor q_invperm," " Tensor q_scale, Tensor q_scale_max, Tensor q_groups," - " Tensor q_group_map) -> uintptr_t"); - ops.impl("make_q_matrix", torch::kCUDA, &make_q_matrix); + " Tensor q_group_map) -> int64_t"); + ops.impl("exl2_make_q_matrix", torch::kCUDA, &make_q_matrix); - ops.def("exl2_gemm(Tensor a, uintptr_t b) -> Tensor"); - ops.impl("exl2_gemm(Tensor a, uintptr_t b) -> Tensor", torch::kCUDA, + ops.def("exl2_gemm(Tensor a, int64_t b) -> Tensor"); + ops.impl("exl2_gemm(Tensor a, int64_t b) -> Tensor", torch::kCUDA, &exl2_gemm); // Mamba selective scan kernel From 4eb9a09ff4bfbe8cda7cfb91a09ec96ab2505184 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Fri, 20 Dec 2024 00:35:27 +0000 Subject: [PATCH 3/4] `int64_t` -> `int` in registration --- csrc/torch_bindings.cpp | 7 +++---- vllm/_custom_ops.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 5d4a7efbe7f21..ec93624abe40f 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -347,12 +347,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def( "exl2_make_q_matrix(Tensor q_weight, Tensor q_perm, Tensor q_invperm," " Tensor q_scale, Tensor q_scale_max, Tensor q_groups," - " Tensor q_group_map) -> int64_t"); + " Tensor q_group_map) -> int"); ops.impl("exl2_make_q_matrix", torch::kCUDA, &make_q_matrix); - ops.def("exl2_gemm(Tensor a, int64_t b) -> Tensor"); - ops.impl("exl2_gemm(Tensor a, int64_t b) -> Tensor", torch::kCUDA, - &exl2_gemm); + ops.def("exl2_gemm(Tensor a, int b) -> Tensor"); + ops.impl("exl2_gemm", torch::kCUDA, &exl2_gemm); // Mamba selective scan kernel ops.def( diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 19f31b8ec419d..cc57c3354d22d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -335,6 +335,20 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, size_n, size_k) +# exllamav2 +def exl2_make_q_matrix(q_weight: torch.Tensor, q_perm: torch.Tensor, + q_invperm: torch.Tensor, q_scale: torch.Tensor, + q_scale_max: torch.Tensor, q_groups: torch.Tensor, + q_group_map: torch.Tensor) -> int: + return torch.ops._C.exl2_make_q_matrix(q_weight, q_perm, q_invperm, + q_scale, q_scale_max, q_groups, + q_group_map) + + +def exl2_gemm(a: torch.Tensor, b: int) -> torch.Tensor: + return torch.ops._C.exl2_gemm(a, b) + + if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): @register_fake("_C::gptq_marlin_24_gemm") From 5bb4bc8a34f8cbca653a7f37d95d1b7a5945d131 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Fri, 20 Dec 2024 01:53:35 +0000 Subject: [PATCH 4/4] add exl2 linear method --- .../layers/quantization/__init__.py | 3 + .../layers/quantization/exllamav2.py | 146 ++++++++++++++++++ 2 files changed, 149 insertions(+) create mode 100644 vllm/model_executor/layers/quantization/exllamav2.py diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index dd10c434f0752..36f0f9c5ec47c 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -26,6 +26,7 @@ "experts_int8", "neuron_quant", "ipex", + "exllamav2" ] @@ -41,6 +42,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: from .compressed_tensors.compressed_tensors import ( # 
noqa: E501 CompressedTensorsConfig) from .deepspeedfp import DeepSpeedFPConfig + from .exllamav2 import Exl2Config from .experts_int8 import ExpertsInt8Config from .fbgemm_fp8 import FBGEMMFp8Config from .fp8 import Fp8Config @@ -79,6 +81,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: "experts_int8": ExpertsInt8Config, "neuron_quant": NeuronQuantConfig, "ipex": IPEXConfig, + "exllamav2": Exl2Config, } return method_to_config[quantization] diff --git a/vllm/model_executor/layers/quantization/exllamav2.py b/vllm/model_executor/layers/quantization/exllamav2.py new file mode 100644 index 0000000000000..a64ee74105306 --- /dev/null +++ b/vllm/model_executor/layers/quantization/exllamav2.py @@ -0,0 +1,146 @@ +from typing import Any, Dict, List, Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs + + +def make_group_map(q_groups, num_qrows): + gr = q_groups.tolist() + group_map = [] + num_groups = len(gr) // 2 + + for i in range(num_groups): + bits = gr[i * 2] + if i < num_groups - 1: + qrows = gr[i * 2 + 3] - gr[i * 2 + 1] + else: + qrows = num_qrows - gr[i * 2 + 1] + rows = qrows * 32 // bits + for j in range(rows): + group_map += [i] + group_map += [rows - j] + return torch.tensor(group_map, dtype=torch.short, device=q_groups.device) + + +class Exl2Config(QuantizationConfig): + """Config class for Exl2.""" + + def __repr__(self) -> str: + return "Exl2Config()" + + @classmethod + def get_name(cls) -> str: + return "exl2" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 60 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "Exl2Config": + return cls() + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["Exl2LinearMethod"]: + if isinstance(layer, LinearBase): + return Exl2LinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + def merge_weight(self) -> bool: + return False + + def quant_vocab(self) -> List[bool]: + return [False, True] + + def support_fused_moe(self) -> bool: + return False + + def rope_style(self) -> Optional[bool]: + return None + + +class Exl2LinearMethod(LinearMethodBase): + """Linear method for Exl2. + + Args: + quant_config: The Exl2 quantization config. 
+ """ + + def __init__(self, quant_config: Exl2Config): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attr): + # The shape of weight is unknown until load state dict + # q_groups, q_invperm, q_scale, q_scale_max, q_weight, q_groups + layer.exllama_state = 0 + qweight = torch.nn.parameter.UninitializedParameter( + requires_grad=False) + set_weight_attrs(qweight, {"output_dim": 1, "ignore_warning": True}) + layer.register_parameter("q_weight", qweight) + qscale = torch.nn.parameter.UninitializedParameter(requires_grad=False) + set_weight_attrs( + qscale, { + "output_dim": 1, + "packed_dim": 1, + "pack_factor": 8, + "ignore_warning": True + }) + layer.register_parameter("q_scale", qscale) + for name in ["q_groups", "q_invperm", "q_scale_max"]: + fake_weight = torch.nn.parameter.UninitializedParameter( + requires_grad=False) + set_weight_attrs(fake_weight, {"ignore_warning": True}) + layer.register_parameter(name, fake_weight) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + out_shape = x.shape[:-1] + (layer.q_weight.shape[-1], ) + reshaped_x = x.reshape(-1, x.shape[-1]) + + if layer.exllama_state == 0: + layer.q_scale_max /= 256 + layer.q_invperm = layer.q_invperm.short() + if not hasattr(layer, 'q_perm'): + layer.q_perm = torch.argsort(layer.q_invperm).to(torch.short) + if not hasattr(layer, 'q_group_map'): + layer.q_group_map = make_group_map(layer.q_groups, + layer.q_weight.shape[0]) + layer.q_matrix = ops.exl2_make_q_matrix( + layer.q_weight, + layer.q_perm, + layer.q_invperm, + layer.q_scale, + layer.q_scale_max, + layer.q_groups, + layer.q_group_map, + ) + layer.exllama_state = 1 + + output = ops.exl2_gemm(reshaped_x, layer.q_matrix) + + if bias is not None: + output.add_(bias) + return output.reshape(out_shape)