Skip to content

Commit

Permalink
upgrade exllama kernel (#209)
Browse files Browse the repository at this point in the history
  • Loading branch information
flozi00 authored Jan 30, 2024
1 parent ebf0cee commit f525ca6
Show file tree
Hide file tree
Showing 66 changed files with 4,889 additions and 319 deletions.
2 changes: 2 additions & 0 deletions server/exllamav2_kernels/exllamav2_kernels/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define _config_h

#define MAX_Q_GEMM_ROWS 50
#define MAX_Q_GEMM_WEIGHTS 4 // must be <= MAX_Q_GEMM_ROWS

#define QMODE_2BIT 1
#define QMODE_3BIT 1
Expand All @@ -10,4 +11,5 @@
#define QMODE_6BIT 0
#define QMODE_8BIT 0


#endif
52 changes: 52 additions & 0 deletions server/exllamav2_kernels/exllamav2_kernels/cpp/quantize_func.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#include "quantize_func.h"
#include "../cuda/quantize.cuh"

// Quantize rows [a, b) of `weights` in place, propagating quantization error
// into the not-yet-quantized rows (error-feedback scheme; presumably GPTQ-style
// with `hessian_inv` the inverse-Hessian factor — TODO confirm against the
// CUDA kernels' definitions).
//
// Parameters:
//   quant       - output tensor receiving the (de)quantized float values.
//   scale       - per-column (or per-group — can't tell from here) scale factors.
//   out_q       - optional integer output; skipped when it lives on the "meta"
//                 device (NULL is passed to the kernel in that case).
//   qzero, maxq - quantization zero point and max quantized value.
//   hessian_inv - square matrix; must have as many columns as `weights` has rows.
//   weights     - [rows, columns] matrix, modified in place.
//   error       - scratch tensor holding per-row quantization error.
//   a, b        - half-open row range [a, b) to process.
//
// NOTE(review): all data_ptr() casts assume float32 (uint16 for out_q) and
// contiguous row-major layout — callers must guarantee this; nothing here checks it.
void quantize_range
(
torch::Tensor quant,
torch::Tensor scale,
torch::Tensor out_q,
float qzero,
float maxq,
torch::Tensor hessian_inv,
torch::Tensor weights,
torch::Tensor error,
int a,
int b
)
{
int columns = weights.size(1);
int hcolumns = hessian_inv.size(1);
// hessian_inv is square over the weight rows: one entry per row pair.
TORCH_CHECK(hcolumns == weights.size(0), "H shape mismatch")

// Sequential per-row pass: each row's quantization error must be folded into
// the remaining rows before the next row is quantized, so this loop cannot
// be parallelized or reordered.
for (int c = a; c < b; c++)
{
// Quantize row c and record its quantization error into `error`.
fused_quantize_adjust_cuda
(
(const float*) weights.data_ptr(),
(float*) quant.data_ptr(),
(const float*) scale.data_ptr(),
out_q.device().is_meta() ? NULL : (uint16_t*) out_q.data_ptr(),
(const float*) hessian_inv.data_ptr(),
(float*) error.data_ptr(),
c, // row
hcolumns, // rows
columns,
qzero,
maxq
);

// Propagate row c's error into rows c..b-1 within the current block:
// weights[c:b] -= hessian_inv[c:b, c] * error[c]  (outer-product update —
// presumably; exact semantics live in vv_mul_sub_cuda, TODO confirm).
// Offsets assume contiguous row-major storage: row c starts at c * row_stride.
vv_mul_sub_cuda
(
((const float*) hessian_inv.data_ptr()) + c * hcolumns + c,
((const float*) error.data_ptr()) + c * columns,
((float*) weights.data_ptr()) + c * columns,
b - c,
columns
);
}

// After the block is done, fold the accumulated block error into all rows
// past b in one GEMM: weights[b:] += -1.0 * (hessian_inv[a:b, b:]^T @ error[a:b])
// (addmm_ with beta = 1.0, alpha = -1.0).
torch::Tensor x = hessian_inv.slice(0, a, b).slice(1, b).transpose(0, 1);
torch::Tensor y = error.slice(0, a, b);
weights.slice(0, b).addmm_(x, y, 1.0f, -1.0f);
}
25 changes: 25 additions & 0 deletions server/exllamav2_kernels/exllamav2_kernels/cpp/quantize_func.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#ifndef _quantize_func_h
#define _quantize_func_h

#include <torch/extension.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <ATen/cuda/CUDAContext.h>
#include <cstdint>
#include <cstdio>

// Quantize rows [a, b) of `weights` in place, writing (de)quantized values to
// `quant` (and optionally integer codes to `out_q`), while propagating the
// per-row quantization error — recorded in `error` — into later rows via
// `hessian_inv` (error-feedback quantization; presumably GPTQ-style, TODO
// confirm in the implementation). `hessian_inv` must have as many columns as
// `weights` has rows. All float tensors are assumed to be contiguous float32;
// `out_q` may be a meta-device tensor to skip the integer output.
void quantize_range
(
torch::Tensor quant,
torch::Tensor scale,
torch::Tensor out_q,
float qzero,
float maxq,
torch::Tensor hessian_inv,
torch::Tensor weights,
torch::Tensor error,
int a,
int b
);

#endif
Loading

0 comments on commit f525ca6

Please sign in to comment.