forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Kernel] Add Exllama as a backend for compressed-tensors (vllm-projec…
- Loading branch information
1 parent
56625f4
commit 15d8bc8
Showing
7 changed files
with
173 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
140 changes: 140 additions & 0 deletions
140
vllm/model_executor/layers/quantization/kernels/exllama.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
from typing import Optional, Tuple | ||
|
||
import torch | ||
|
||
from vllm import _custom_ops as ops | ||
from vllm.model_executor.layers.quantization.utils.quant_utils import ( | ||
pack_quantized_values_into_int32) | ||
from vllm.model_executor.parameter import (BasevLLMParameter, | ||
permute_param_layout_) | ||
from vllm.scalar_type import scalar_types | ||
|
||
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig | ||
|
||
|
||
class ExllamaLinearKernel(MPLinearKernel): | ||
SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] | ||
# In theory supports `scalar_types.uint2b2, scalar_types.uint3b4` too but | ||
# currently untested so not added to the list | ||
|
||
@classmethod | ||
def get_min_capability(cls) -> int: | ||
return 60 | ||
|
||
@classmethod | ||
def can_implement(cls, | ||
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: | ||
if c.has_g_idx and\ | ||
c.partition_weight_shape[0] != c.full_weight_shape[0]: | ||
return False, "Act reordering currently not supported by Exllama, "\ | ||
"when the input features are partitioned across "\ | ||
"devices" | ||
|
||
if c.partition_weight_shape[1] % (32 // c.weight_type.size_bits) != 0: | ||
return False, "Output features must be a multiple of the pack " \ | ||
"factor (32 / num_bits) so that we can correctly " \ | ||
"pack the zero points" | ||
|
||
if c.act_type != torch.float16: | ||
return False, "Exllama only supports float16 activations" | ||
|
||
if c.weight_type not in cls.SUPPORTED_QUANT_TYPES: | ||
return False, f"Quant type ({c.weight_type}) not supported by "\ | ||
"Exllama, supported types are: "\ | ||
f"{cls.SUPPORTED_QUANT_TYPES}" | ||
|
||
if c.full_weight_shape[0] % c.group_size != 0: | ||
return False, f"Group size ({c.group_size}) does not evenly divide"\ | ||
" the number of input features "\ | ||
f"({c.full_weight_shape[0]})" | ||
|
||
return True, None | ||
|
||
def process_weights_after_loading(self, layer: torch.nn.Module): | ||
c = self.config | ||
|
||
# For Exllama, we need to set a zero-point tensor if there is not one | ||
if not c.zero_points: | ||
self.w_zp_name = "qzeros" | ||
device = getattr(layer, self.w_q_name).device | ||
groups = c.partition_weight_shape[0] // c.group_size | ||
out_features = c.partition_weight_shape[1] | ||
|
||
if c.weight_type.has_bias(): | ||
# if the type has a bias we have to create a zeros tensor that | ||
# contains the bias values repeated for each group (-1 due to | ||
# a bug in the original GPTQ checkpoint format leading to | ||
# exllama kernel adding 1 to the zero points during inference) | ||
# Documentation of the bug can be found here: | ||
# https://garden.danieldk.eu/GPTQ-Checkpoint-Format | ||
zeros = torch.full((groups, out_features), | ||
c.weight_type.bias - 1, | ||
dtype=torch.int32, | ||
device=device) | ||
else: | ||
raise NotImplementedError( | ||
"A 0 zero-point is not supported by Exllama due to " | ||
"a bug in the original GPTQ checkpoint format leading to " | ||
"exllama kernel adding 1 to the zero points during " | ||
"inference") | ||
zeros = pack_quantized_values_into_int32(zeros, | ||
c.weight_type, | ||
packed_dim=1) | ||
setattr(layer, self.w_zp_name, | ||
torch.nn.Parameter(zeros, requires_grad=False)) | ||
|
||
if c.has_g_idx: | ||
|
||
def transform_w_g_idx(x): | ||
# Exllama wants the permutation array instead of the group | ||
# indices | ||
return torch.argsort(x).to(torch.int) | ||
|
||
self._transform_param(layer, self.w_gidx_name, transform_w_g_idx) | ||
else: | ||
self.w_gidx_name = "g_idx" | ||
empty_g_idx = torch.nn.Parameter(torch.empty((0, ), | ||
dtype=torch.int, | ||
device=device), | ||
requires_grad=False) | ||
setattr(layer, self.w_gidx_name, empty_g_idx) | ||
|
||
def transform_w_q(x): | ||
assert isinstance(x, BasevLLMParameter) | ||
assert self.w_gidx_name is not None | ||
g_idx = getattr(layer, self.w_gidx_name) | ||
|
||
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) | ||
x_cont = x.data.contiguous() | ||
ops.gptq_shuffle(x_cont, g_idx, c.weight_type.size_bits) | ||
return x_cont | ||
|
||
def transform_w_s(x): | ||
assert isinstance(x, BasevLLMParameter) | ||
permute_param_layout_(x, input_dim=0, output_dim=1) | ||
x.data = x.data.contiguous() | ||
return x.to(dtype=c.act_type) | ||
|
||
# Repack weights and scales for Machete | ||
self._transform_param(layer, self.w_q_name, transform_w_q) | ||
self._transform_param(layer, self.w_s_name, transform_w_s) | ||
|
||
def apply_weights(self, | ||
layer: torch.nn.Module, | ||
x: torch.Tensor, | ||
bias: Optional[torch.Tensor] = None) -> torch.Tensor: | ||
c = self.config | ||
|
||
x_2d = x.reshape(-1, x.shape[-1]) | ||
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) | ||
|
||
w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer) | ||
|
||
assert w_zp is not None, "Zero points are required by Exllama" | ||
assert w_g_idx is not None, "Group index is required by Exllama" | ||
output = ops.gptq_gemm(x_2d, w_q, w_zp, w_s, w_g_idx, True, | ||
c.weight_type.size_bits) | ||
|
||
if bias is not None: | ||
output.add_(bias) | ||
return output.reshape(out_shape) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters