From 26a1a213de768e122e9843f421415adf86a20f24 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 17 Oct 2024 09:48:26 -0400 Subject: [PATCH] [Kernel] Add Exllama as a backend for compressed-tensors (#9395) Signed-off-by: qishuai --- vllm/envs.py | 9 ++ .../quantization/kernels/MPLinearKernel.py | 4 + .../layers/quantization/kernels/__init__.py | 8 +- .../layers/quantization/kernels/exllama.py | 140 ++++++++++++++++++ .../layers/quantization/kernels/machete.py | 14 +- .../layers/quantization/utils/quant_utils.py | 12 +- vllm/scalar_type.py | 2 + 7 files changed, 173 insertions(+), 16 deletions(-) create mode 100644 vllm/model_executor/layers/quantization/kernels/exllama.py diff --git a/vllm/envs.py b/vllm/envs.py index 8b541e5b78c01..45a9999610f6a 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -66,6 +66,7 @@ VLLM_SKIP_P2P_CHECK: bool = False VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1: bool = False VLLM_TORCH_COMPILE_LEVEL: int = 0 + VLLM_DISABLED_KERNELS: List[str] = [] def get_default_cache_root(): @@ -430,6 +431,14 @@ def get_default_config_root(): "VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1": lambda: os.environ.get("VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1", "0" ) == "1", + + # List of quantization kernels that should be disabled, used for testing + # and performance comparisons. Currently only affects MPLinearKernel + # selection + # (kernels: MacheteLinearKernel, MarlinLinearKernel, ExllamaLinearKernel) + "VLLM_DISABLED_KERNELS": + lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[ + "VLLM_DISABLED_KERNELS"].split(","), } # end-env-vars-definition diff --git a/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py index fe50c4930d043..b04612a9b00d9 100644 --- a/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py @@ -42,6 +42,10 @@ def __init__(self, self.config = c self.w_q_name = w_q_param_name self.w_s_name = w_s_param_name + if c.zero_points: + assert w_zp_param_name is not None + if c.has_g_idx: + assert w_gidx_param_name is not None self.w_zp_name = w_zp_param_name self.w_gidx_name = w_gidx_param_name diff --git a/vllm/model_executor/layers/quantization/kernels/__init__.py b/vllm/model_executor/layers/quantization/kernels/__init__.py index 47591c2aa644e..94a3dc2584d6b 100644 --- a/vllm/model_executor/layers/quantization/kernels/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/__init__.py @@ -1,6 +1,8 @@ -import os from typing import List, Optional, Type +import vllm.envs as envs +from vllm.model_executor.layers.quantization.kernels.exllama import ( + ExllamaLinearKernel) from vllm.model_executor.layers.quantization.kernels.machete import ( MacheteLinearKernel) from vllm.model_executor.layers.quantization.kernels.marlin import ( @@ -13,6 +15,7 @@ _POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ MacheteLinearKernel, MarlinLinearKernel, + ExllamaLinearKernel, ] @@ -45,8 +48,7 @@ def choose_mp_linear_kernel( failure_reasons = [] for kernel in _POSSIBLE_KERNELS: - if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\ - .split(","): + if kernel.__name__ in envs.VLLM_DISABLED_KERNELS: failure_reasons.append( f' {kernel.__name__} disabled by environment variable') continue diff --git a/vllm/model_executor/layers/quantization/kernels/exllama.py b/vllm/model_executor/layers/quantization/kernels/exllama.py new file mode 100644 index 0000000000000..1d85d62ec83ee --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/exllama.py @@ -0,0 +1,140 @@ +from typing import Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + pack_quantized_values_into_int32) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) +from vllm.scalar_type import scalar_types + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class ExllamaLinearKernel(MPLinearKernel): + SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] + # In theory supports `scalar_types.uint2b2, scalar_types.uint3b4` too but + # currently untested so not added to the list + + @classmethod + def get_min_capability(cls) -> int: + return 60 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if c.has_g_idx and\ + c.partition_weight_shape[0] != c.full_weight_shape[0]: + return False, "Act reordering currently not supported by Exllama, "\ + "when the input features are partitioned across "\ + "devices" + + if c.partition_weight_shape[1] % (32 // c.weight_type.size_bits) != 0: + return False, "Output features must be a multiple of the pack " \ + "factor (32 / num_bits) so that we can correctly " \ + "pack the zero points" + + if c.act_type != torch.float16: + return False, "Exllama only supports float16 activations" + + if c.weight_type not in cls.SUPPORTED_QUANT_TYPES: + return False, f"Quant type ({c.weight_type}) not supported by "\ + "Exllama, supported types are: "\ + f"{cls.SUPPORTED_QUANT_TYPES}" + + if c.full_weight_shape[0] % c.group_size != 0: + return False, f"Group size ({c.group_size}) does not evenly divide"\ + " the number of input features "\ + f"({c.full_weight_shape[0]})" + + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module): + c = self.config + + # For Exllama, we need to set a zero-point tensor if there is not one + if not c.zero_points: + self.w_zp_name = "qzeros" + device = getattr(layer, self.w_q_name).device + groups = c.partition_weight_shape[0] // c.group_size + out_features = c.partition_weight_shape[1] + + if c.weight_type.has_bias(): + # if the type has a bias we have to create a zeros tensor that + # contains the bias values repeated for each group (-1 due to + # a bug in the original GPTQ checkpoint format leading to + # exllama kernel adding 1 to the zero points during inference) + # Documentation of the bug can be found here: + # https://garden.danieldk.eu/GPTQ-Checkpoint-Format + zeros = torch.full((groups, out_features), + c.weight_type.bias - 1, + dtype=torch.int32, + device=device) + else: + raise NotImplementedError( + "A 0 zero-point is not supported by Exllama due to " + "a bug in the original GPTQ checkpoint format leading to " + "exllama kernel adding 1 to the zero points during " + "inference") + zeros = pack_quantized_values_into_int32(zeros, + c.weight_type, + packed_dim=1) + setattr(layer, self.w_zp_name, + torch.nn.Parameter(zeros, requires_grad=False)) + + if c.has_g_idx: + + def transform_w_g_idx(x): + # Exllama wants the permutation array instead of the group + # indices + return torch.argsort(x).to(torch.int) + + self._transform_param(layer, self.w_gidx_name, transform_w_g_idx) + else: + self.w_gidx_name = "g_idx" + empty_g_idx = torch.nn.Parameter(torch.empty((0, ), + dtype=torch.int, + device=device), + requires_grad=False) + setattr(layer, self.w_gidx_name, empty_g_idx) + + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + assert self.w_gidx_name is not None + g_idx = getattr(layer, self.w_gidx_name) + + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + x_cont = x.data.contiguous() + ops.gptq_shuffle(x_cont, g_idx, c.weight_type.size_bits) + return x_cont + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = x.data.contiguous() + return x.to(dtype=c.act_type) + + # Repack weights and scales for Machete + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + + x_2d = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + + w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer) + + assert w_zp is not None, "Zero points are required by Exllama" + assert w_g_idx is not None, "Group index is required by Exllama" + output = ops.gptq_gemm(x_2d, w_q, w_zp, w_s, w_g_idx, True, + c.weight_type.size_bits) + + if bias is not None: + output.add_(bias) + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/kernels/machete.py b/vllm/model_executor/layers/quantization/kernels/machete.py index fa39cb511528e..e5696d08f30f5 100644 --- a/vllm/model_executor/layers/quantization/kernels/machete.py +++ b/vllm/model_executor/layers/quantization/kernels/machete.py @@ -8,7 +8,7 @@ MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape, query_machete_supported_quant_types) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - pack_weights_into_int32, unpack_weights_into_int32) + pack_quantized_values_into_int32, unpack_quantized_values_into_int32) from vllm.model_executor.parameter import (BasevLLMParameter, permute_param_layout_) @@ -71,13 +71,13 @@ def transform_w_q(x): assert isinstance(x, BasevLLMParameter) permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) if c.has_g_idx: - x_unpacked = unpack_weights_into_int32(x.data, - c.weight_type, - packed_dim=0) + x_unpacked = unpack_quantized_values_into_int32(x.data, + c.weight_type, + packed_dim=0) x_perm = x_unpacked[perm, :] - x.data = pack_weights_into_int32(x_perm, - c.weight_type, - packed_dim=0) + x.data = pack_quantized_values_into_int32(x_perm, + c.weight_type, + packed_dim=0) x.data = ops.machete_prepack_B(x.data.t().contiguous().t(), self.config.weight_type) return x diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 833d00073564e..c217f5ca620a1 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -20,9 +20,9 @@ } -def pack_weights_into_int32(w_q: torch.Tensor, - wtype: ScalarType, - packed_dim: int = 0): +def pack_quantized_values_into_int32(w_q: torch.Tensor, + wtype: ScalarType, + packed_dim: int = 0): # move dim to pack to the end perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) inv_perm = tuple(perm.index(i) for i in range(len(perm))) @@ -42,9 +42,9 @@ def pack_weights_into_int32(w_q: torch.Tensor, return res.permute(inv_perm) -def unpack_weights_into_int32(w_q: torch.Tensor, - wtype: ScalarType, - packed_dim: int = 0): +def unpack_quantized_values_into_int32(w_q: torch.Tensor, + wtype: ScalarType, + packed_dim: int = 0): # move dim to pack to the end perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) inv_perm = tuple(perm.index(i) for i in range(len(perm))) diff --git a/vllm/scalar_type.py b/vllm/scalar_type.py index eb491dd1554a8..373151a5311e5 100644 --- a/vllm/scalar_type.py +++ b/vllm/scalar_type.py @@ -27,6 +27,8 @@ class scalar_types: float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE.value) # "gptq" types + uint2b2 = ScalarType.uint(2, 2) + uint3b4 = ScalarType.uint(3, 4) uint4b8 = ScalarType.uint(4, 8) uint8b128 = ScalarType.uint(8, 128)