From 2be62187d83c59285127336807d32cbd61a1fb05 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 1 Jul 2024 15:37:00 +0000 Subject: [PATCH 01/24] Support Repack from GPTQ. --- .gitignore | 3 + vllm/config.py | 3 +- vllm/model_executor/layers/linear.py | 23 +- .../layers/quantization/__init__.py | 4 + .../layers/quantization/bitblas.py | 364 +++++++++++++ .../layers/quantization/gptq_bitblas.py | 499 ++++++++++++++++++ 6 files changed, 891 insertions(+), 5 deletions(-) create mode 100644 vllm/model_executor/layers/quantization/bitblas.py create mode 100644 vllm/model_executor/layers/quantization/gptq_bitblas.py diff --git a/.gitignore b/.gitignore index e077366d1e4a..47289386c0e5 100644 --- a/.gitignore +++ b/.gitignore @@ -187,3 +187,6 @@ hip_compat.h # Benchmark dataset *.json + +# Debug files +./debug diff --git a/vllm/config.py b/vllm/config.py index 3551e8f6fa03..d42762448964 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -230,7 +230,8 @@ def _verify_quantization(self) -> None: f"{self.quantization} quantization is currently not " f"supported in ROCm.") if (self.quantization - not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin")): + not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin", + "gptq_bitblas", "bitblas")): logger.warning( "%s quantization is not fully " "optimized yet. The speed can be slower than " diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 3cc257834033..c567ad9d649f 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -26,6 +26,14 @@ def adjust_marlin_shard(param, shard_size, shard_offset): return shard_size * marlin_tile_size, shard_offset * marlin_tile_size +def adjust_bitblas_shard(param, shard_size, shard_offset): + bitblas_tile_size = getattr(param, "bitblas_tile_size", None) + if bitblas_tile_size is None: + return shard_size, shard_offset + + return shard_size // bitblas_tile_size, shard_offset // bitblas_tile_size + + def adjust_bitsandbytes_shard(param: Parameter, qkv_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str) -> Tuple[int, int]: @@ -306,7 +314,8 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): if len(loaded_weight.shape) == 0: loaded_weight = loaded_weight.reshape(1) - assert param_data.shape == loaded_weight.shape + assert param_data.dtype == loaded_weight.dtype, f"{param.data.dtype} != {loaded_weight.dtype}" + assert param_data.shape == loaded_weight.shape, f"{param_data.shape} != {loaded_weight.shape}" param_data.copy_(loaded_weight) def forward(self, input_): @@ -411,6 +420,9 @@ def weight_loader(self, shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) + shard_size, shard_offset = adjust_bitblas_shard( + param, shard_size, shard_offset) + loaded_weight_shard = loaded_weight.narrow( output_dim, shard_offset, shard_size) self.weight_loader(param, loaded_weight_shard, shard_id) @@ -432,7 +444,8 @@ def weight_loader(self, # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) - + shard_size, shard_offset = adjust_bitblas_shard( + param, shard_size, shard_offset) use_bitsandbytes = getattr(param, "use_bitsandbytes", False) if use_bitsandbytes: shard_size = loaded_weight.shape[output_dim] @@ -576,7 +589,8 @@ def weight_loader(self, # Special case for Marlin. 
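                # Marlin and BitBLAS both repack weights into tiles, so the
                # logical shard geometry has to be rescaled before narrowing.
                # Rough sanity sketch of the BitBLAS case added below
                # (illustrative only; assumes a parameter tagged with a
                # `bitblas_tile_size` attribute of 16):
                #   adjust_bitblas_shard(param, shard_size=64, shard_offset=128)
                #   -> (4, 8)   # sizes/offsets are *divided* by the tile size,
                #               # whereas Marlin multiplies by its tile size.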
shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) - + shard_size, shard_offset = adjust_bitblas_shard( + param, shard_size, shard_offset) loaded_weight_shard = loaded_weight.narrow( output_dim, shard_offset, shard_size) self.weight_loader(param, loaded_weight_shard, shard_id) @@ -608,7 +622,8 @@ def weight_loader(self, # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) - + shard_size, shard_offset = adjust_bitblas_shard( + param, shard_size, shard_offset) use_bitsandbytes = getattr(param, "use_bitsandbytes", False) if use_bitsandbytes: orig_qkv_offsets = { diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 40b0df75a69a..7eaf9c4832e0 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -16,7 +16,9 @@ GPTQMarlinConfig) from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQMarlin24Config) +from vllm.model_executor.layers.quantization.gptq_bitblas import GPTQBitBLASConfig from vllm.model_executor.layers.quantization.marlin import MarlinConfig +from vllm.model_executor.layers.quantization.bitblas import BitBLASConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { @@ -27,8 +29,10 @@ # The order of gptq methods is important for config.py iteration over # override_quantization_method(..) "marlin": MarlinConfig, + "bitblas": BitBLASConfig, "gptq_marlin_24": GPTQMarlin24Config, "gptq_marlin": GPTQMarlinConfig, + "gptq_bitblas": GPTQBitBLASConfig, "gptq": GPTQConfig, "squeezellm": SqueezeLLMConfig, "compressed-tensors": CompressedTensorsConfig, diff --git a/vllm/model_executor/layers/quantization/bitblas.py b/vllm/model_executor/layers/quantization/bitblas.py new file mode 100644 index 000000000000..17f4211f6efa --- /dev/null +++ b/vllm/model_executor/layers/quantization/bitblas.py @@ -0,0 +1,364 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs + +logger = init_logger(__name__) + +try: + import bitblas +except ImportError as e: + bitblas_import_exception = e + raise ValueError( + f"Trying to use the bitblas backend, but could not import dependencies with the following error: {bitblas_import_exception}" + ) + +import bitblas +from bitblas.utils import auto_detect_nvidia_target + +BITBLAS_TARGET = auto_detect_nvidia_target() +BITBLAS_DATABASE_PATH = bitblas.cache.get_database_path() + +BITBLAS_SUPPORTED_NUM_BITS = [1, 2, 4, 8] +BITBLAS_SUPPORTED_SYM = [False, True] + +class BitBLASConfig(QuantizationConfig): + """Config class for BitBLAS. 
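
    Holds the checkpoint's quantize settings (weight_bits, group_size,
    desc_act, sym) and derives the packing factor from the int8 storage
    dtype used by the BitBLAS kernels.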
+ + Reference: https://github.com/Microsoft/BitBLAS + """ + TORCH_DTYPE = torch.float16 + STORAGE_DTYPE = "int8" # assume int8 storage + TORCH_STORAGE_DTYPE = getattr(torch, STORAGE_DTYPE) + ZEROS_TYPE = "quantized" # "original" or "rescale" or "quantized", the gptq_bitblas prefer "quantized" + + def __init__(self, weight_bits: int, group_size: int, desc_act: bool, + is_sym: bool) -> None: + if desc_act and group_size == -1: + # In this case, act_order == True is the same as act_order == False + # (since we have only one group per output channel) + desc_act = False + + self.weight_bits = weight_bits + self.group_size = group_size + self.desc_act = desc_act + self.is_sym = is_sym + + # Verify + if self.weight_bits not in BITBLAS_SUPPORTED_NUM_BITS: + raise ValueError( + f"BitBLAS does not support weight_bits = {self.weight_bits}. " + f"Only weight_bits = {BITBLAS_SUPPORTED_NUM_BITS} " + "are supported.") + + if self.is_sym not in BITBLAS_SUPPORTED_SYM: + raise ValueError( + f"BitBLAS does not support is_sym = {self.is_sym}. " + f"Only sym = {BITBLAS_SUPPORTED_SYM} are supported.") + + storage_dtype = self.STORAGE_DTYPE + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + + self.storage_dtype = storage_dtype + self.storage_torch_dtype = self.TORCH_STORAGE_DTYPE + # 4 Bits packed into 32 bit datatype. + self.pack_factor = storage_nbit // weight_bits + self.nbits = weight_bits + + # Zeros type for the quantized weights. + self.zeros_type = self.ZEROS_TYPE + + def __repr__(self) -> str: + return (f"BitBLASConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act})") + + @classmethod + def get_name(cls) -> str: + return "bitblas" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 70 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "BitBLASConfig": + group_size = cls.get_from_keys(config, ["group_size"]) + return cls(group_size) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + # compat: autogptq >=0.8.0 use checkpoint_format: str + # compat: autogptq <=0.7.1 is_bitblas_format: bool + is_bitblas_format = (hf_quant_cfg.get("checkpoint_format") == "bitblas" + or hf_quant_cfg.get("is_bitblas_format", False)) + + is_valid_user_quant = (user_quant is None or user_quant == "gptq" + or user_quant == "bitblas") + + if is_bitblas_format and is_valid_user_quant: + msg = ("The model is serialized in {} format. Using {} kernel.". + format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + return None + + def get_quant_method( + self, layer: torch.nn.Module) -> Optional["BitBLASLinearMethod"]: + if isinstance(layer, LinearBase): + return BitBLASLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class BitBLASLinearMethod(LinearMethodBase): + """Linear method for BitBLAS. + + Args: + quant_config: The BitBLAS quantization config. 
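
    Example (illustrative sketch only; the values are made up, not taken
    from a real checkpoint):

        config = BitBLASConfig(weight_bits=4, group_size=128,
                               desc_act=False, is_sym=False)
        method = BitBLASLinearMethod(config)
        # create_weights(...) then allocates the packed qweight plus the
        # scales/zeros buffers laid out for the BitBLAS "nt" kernel.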
+ """ + OPT_FEATURES = [1, 16, 32, 64, 128, 256, 512] + ENABLE_TUNING = True + + def __init__(self, quant_config: BitBLASConfig): + self.quant_config = quant_config + + def create_weights( + self, + input_size_per_partition: int, + output_partition_sizes: int, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> Dict[str, Any]: + """Creates quantized weights for use in linear operations. + + The function initializes and returns a dictionary containing quantized weights, scales, and zeros + for performing quantized matrix multiplication operations. + + Args: + input_size_per_partition: The size of the input partition. + output_size_per_partition: The size of the output partition. + input_size: The total size of the input (unused). + output_size: The total size of the output (unused). + params_dtype: The data type of the parameters (expected to be torch.float16). + + Returns: + A dictionary containing the quantized weights ('qweight'), scales ('scales'), and zeros ('zeros'). + + Raises: + ValueError: If `params_dtype` is not `torch.float16` or if the input size per partition + is not divisible by the group size in `quant_config`. + """ + del input_size, output_size # Unused arguments. + if params_dtype != torch.float16: + raise ValueError( + f"Parameter data type must be torch.float16, but got {params_dtype}" + ) + + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + if ( + self.quant_config.group_size != -1 + and input_size_per_partition % self.quant_config.group_size != 0 + ): + raise ValueError( + f"Input size per partition ({input_size_per_partition}) must be divisible by " + f"group size ({self.quant_config.group_size})." + ) + + # Initialize or retrieve the BitBLAS matrix multiplication operator. + self._configure_bitblas_matmul( + input_size_per_partition, + output_size_per_partition, + enable_tuning=self.ENABLE_TUNING, + bias=False, + layout="nt", + bits=self.quant_config.weight_bits, + ) + # Initialize quantized weights with dimensions optimized for BitBLAS operations. + + qweight = Parameter( + torch.empty( + self.bitblas_matmul.retrieve_weight_shape(), + device="cuda", + dtype=self.quant_config.storage_torch_dtype, + ), + requires_grad=False, + ) + # Attributes to help with unpacking and applying the weights later. + set_weight_attrs( + qweight, + { + "input_dim": 1, + "output_dim": 0, + "packed_dim": 1, + "bitblas_tile_size": ( + self.bitblas_matmul.retrieve_weight_shape()[-2] + if self.quant_config.weight_propagation + else None + ), + "pack_factor": self.quant_config.pack_factor, + "weight_propagation": self.quant_config.weight_propagation, + }, + ) + + # Compute the number of input groups for channel-wise quantization. + input_groups = ( + 1 + if self.quant_config.group_size == -1 + else input_size_per_partition // self.quant_config.group_size + ) + + # Initialize scales and zeros for the quantized weights. + scales = Parameter( + torch.empty( + output_size_per_partition, + input_groups, + device="cuda", + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(scales, {"input_dim": None if input_groups == 1 else 1, "output_dim": 0}) + if self.quant_config.zeros_type == "quantized": + zeros = Parameter( + torch.empty( + input_groups, + output_size_per_partition // self.quant_config.pack_factor, + device="cuda", + dtype=self.quant_config.storage_torch_dtype, + ), + requires_grad=False, + ) + # Set attributes to indicate how scales and zeros are applied. 
+ + set_weight_attrs( + zeros, + { + "input_dim": None if input_groups == 1 else 0, + "output_dim": 1, + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + }, + ) + else: + zeros = Parameter( + torch.empty(output_size_per_partition, input_groups, + device="cuda", + dtype=params_dtype), + requires_grad=False, + ) + # Set attributes to indicate how scales and zeros are applied. + set_weight_attrs(scales, {"input_dim": None if input_groups == 1 else 1, "output_dim": 0}) + + return {"qweight": qweight, "scales": scales, "zeros": zeros} + + def _configure_bitblas_matmul( + self, + infeatures, + outfeatures, + enable_tuning, + bias, + layout, + bits, + ): + from bitblas import MatmulConfig + + bitblas_dtype = self.BITBLAS_DTYPES[self.TORCH_DTYPE] + + W_dtype = f"uint{bits}" + + matmul_config = MatmulConfig( + M=self.OPT_FEATURES, + N=outfeatures, + K=infeatures, + A_dtype=bitblas_dtype, + W_dtype=W_dtype, + out_dtype=bitblas_dtype, + accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype, + storage_dtype=self.STORAGE_DTYPE, + with_scaling=True, + with_zeros=True, + group_size=self.group_size, + with_bias=bias, + layout=layout, + zeros_mode=self.zeros_mode, + ) + self.bitblas_matmul = self._get_or_create_bitblas_operator( + matmul_config, enable_tuning + ) + + def _get_or_create_bitblas_operator(self, config, enable_tuning): + from bitblas import Matmul + from bitblas.cache import global_operator_cache + + if global_operator_cache.size() == 0: + global_operator_cache.load_from_database( + BITBLAS_DATABASE_PATH, BITBLAS_TARGET + ) + + bitblas_matmul = global_operator_cache.get(config) + if bitblas_matmul is None: + bitblas_matmul = Matmul(config, target=self.target) + if enable_tuning: + bitblas_matmul.hardware_aware_finetune(topk=20) + global_operator_cache.add(config, bitblas_matmul) + global_operator_cache.save_into_database( + BITBLAS_DATABASE_PATH, BITBLAS_TARGET + ) + print( + "BitBLAS Tuning done, appended operator to global_operator_cache." 
+ ) + else: + print("BitBLAS Operator created.") + else: + print("BitBLAS Operator found in global_operator_cache.") + return bitblas_matmul + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qweight = layer.B + scales = layer.s + workspace = layer.workspace + + x_2d = x.view(-1, x.shape[-1]) + + size_m = x_2d.shape[0] + size_k = x_2d.shape[1] + size_n = scales.shape[1] + + output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m, + size_n, size_k) + + output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) + + if bias is not None: + output.add_(bias) # In-place add + + return output diff --git a/vllm/model_executor/layers/quantization/gptq_bitblas.py b/vllm/model_executor/layers/quantization/gptq_bitblas.py new file mode 100644 index 000000000000..780154d7047c --- /dev/null +++ b/vllm/model_executor/layers/quantization/gptq_bitblas.py @@ -0,0 +1,499 @@ +import enum +from enum import Enum +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +import bitblas.cache +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + set_weight_attrs, +) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig + +logger = init_logger(__name__) + + +try: + import bitblas +except ImportError as e: + bitblas_import_exception = e + raise ValueError( + f"Trying to use the bitblas backend, but could not import dependencies with the following error: {bitblas_import_exception}" + ) + +import bitblas +from bitblas.utils import auto_detect_nvidia_target + +bitblas.set_log_level("Debug") +BITBLAS_TARGET = auto_detect_nvidia_target() +BITBLAS_DATABASE_PATH = bitblas.cache.get_database_path() + +GPTQ_BITBLAS_SUPPORTED_NUM_BITS = [1, 2, 4, 8] +GPTQ_BITBLAS_SUPPORTED_SYM = [False, True] + + +def unpack_qzeros(qzeros, bits): + qzeros = qzeros.view(torch.int32) + elems_per_int32 = 32 // bits + unpacked_zeros = torch.zeros( + (qzeros.shape[0], qzeros.shape[1] * elems_per_int32), + dtype=torch.int8, + device=qzeros.device, + requires_grad=False, + ) + + for col in range(unpacked_zeros.shape[1]): + i = col % elems_per_int32 + unpacked_zeros[:, col] = (qzeros[:, col // elems_per_int32] >> (bits * i)) & 0xF + + return unpacked_zeros + 1 + + + +class GPTQBitBLASConfig(QuantizationConfig): + """Config class for GPTQ BitBLAS""" + + TORCH_DTYPE = torch.float16 + GPTQ_CKPT_STORAGE_DTYPE = "int32" # GPTQ Default Checkpoints use int32 as storage dtype + GPTQ_BITBLAS_STORAGE_DTYPE = "int8" # BitBLAS uses int8 as storage dtype + TORCH_BITBLAS_STORAGE_DTYPE = getattr(torch, GPTQ_BITBLAS_STORAGE_DTYPE) + ZEROS_MODE = "quantized" # "original" or "rescale" or "quantized", the gptq_bitblas prefer "quantized" + + def __init__( + self, weight_bits: int, group_size: int, desc_act: bool, is_sym: bool + ) -> None: + if desc_act and group_size == -1: + # In this case, act_order == True is the same as act_order == False + # (since we have only one group per output channel) + desc_act = False + + self.weight_bits = weight_bits + self.group_size = group_size + self.desc_act = desc_act + self.is_sym = is_sym + + # Verify + if self.weight_bits not in GPTQ_BITBLAS_SUPPORTED_NUM_BITS: + raise ValueError( + f"BitBLAS does not support weight_bits = {self.weight_bits}. " + f"Only weight_bits = {GPTQ_BITBLAS_SUPPORTED_NUM_BITS} " + "are supported." 
+ ) + + if self.is_sym not in GPTQ_BITBLAS_SUPPORTED_SYM: + raise ValueError( + f"BitBLAS does not support is_sym = {self.is_sym}. " + f"Only sym = {GPTQ_BITBLAS_SUPPORTED_SYM} are supported." + ) + + + self.storage_dtype = self.GPTQ_BITBLAS_STORAGE_DTYPE + + storage_nbit = int("".join(c for c in self.GPTQ_CKPT_STORAGE_DTYPE if c.isdigit())) + + # 4 Bits packed into 32 bit datatype. + self.pack_factor = storage_nbit // weight_bits + self.nbits = weight_bits + + # Zeros type for the quantized weights. + self.zeros_mode = self.ZEROS_MODE + + def __repr__(self) -> str: + return ( + f"GPTQBitBLASConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act})" + ) + + @classmethod + def get_name(cls) -> str: + return "gptq_bitblas" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "GPTQBitBLASConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + desc_act = cls.get_from_keys(config, ["desc_act"]) + is_sym = cls.get_from_keys(config, ["sym"]) + return cls(weight_bits, group_size, desc_act, is_sym) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: + can_convert = cls.is_bitblas_compatible(hf_quant_cfg) + + is_valid_user_quant = user_quant is None or user_quant == "bitblas" + + if can_convert and is_valid_user_quant: + msg = ( + "The model is convertible to {} during runtime." + " Using {} kernel.".format(cls.get_name(), cls.get_name()) + ) + logger.info(msg) + return cls.get_name() + + if can_convert and user_quant == "gptq": + logger.info( + "Detected that the model can run with gptq_bitblas" + ", however you specified quantization=gptq explicitly," + " so forcing gptq. Use quantization=gptq_bitblas for" + " faster inference" + ) + return None + + def get_quant_method( + self, layer: torch.nn.Module + ) -> Optional["GPTQBitBLASLinearMethod"]: + if isinstance(layer, LinearBase): + return GPTQBitBLASLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + @classmethod + def is_bitblas_compatible(cls, quant_config: Dict[str, Any]): + # Extract data from quant config. + num_bits = quant_config.get("bits", None) + group_size = quant_config.get("group_size", None) + sym = quant_config.get("sym", None) + desc_act = quant_config.get("desc_act", None) + + # If we cannot find the info needed in the config, cannot convert. + if num_bits is None or group_size is None or sym is None or desc_act is None: + return False + + # If the capability of the device is too low, cannot convert. + major, minor = torch.cuda.get_device_capability() + device_capability = major * 10 + minor + if device_capability < cls.get_min_capability(): + return False + + # Otherwise, can convert if model satisfies bitblas constraints. + return ( + num_bits in GPTQ_BITBLAS_SUPPORTED_NUM_BITS + and sym in GPTQ_BITBLAS_SUPPORTED_SYM + ) + + +class GPTQBitBLASState(Enum): + REPACK = enum.auto() + READY = enum.auto() + + +class GPTQBitBLASLinearMethod(LinearMethodBase): + """Linear method for GPTQ BitBLAS. + + Args: + quant_config: The GPTQ BitBLAS quantization config. 
+ """ + + OPT_FEATURES = [1, 16, 32, 64, 128, 256, 512] + ENABLE_TUNING = True + BITBLAS_DTYPES = { + torch.float32: "float32", + torch.float16: "float16", + torch.half: "float16", + torch.int8: "int8", + } + + def __init__(self, quant_config: GPTQBitBLASConfig) -> None: + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + """Creates quantized weights for use in linear operations. + + The function initializes and returns a dictionary containing quantized weights, scales, and zeros + for performing quantized matrix multiplication operations. + + Args: + input_size_per_partition: The size of the input partition. + output_partition_sizes: The size of the output partition. + input_size: The total size of the input (unused). + output_size: The total size of the output (unused). + params_dtype: The data type of the parameters (expected to be torch.float16). + + Returns: + A dictionary containing the quantized weights ('qweight'), scales ('scales'), and zeros ('zeros'). + + Raises: + ValueError: If `params_dtype` is not `torch.float16` or if the input size per partition + is not divisible by the group size in `quant_config`. + """ + del output_size # Unused arguments. + if params_dtype != torch.float16: + raise ValueError( + f"Parameter data type must be torch.float16, but got {params_dtype}" + ) + + # Normalize group_size + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + + if input_size_per_partition % group_size != 0: + raise ValueError( + f"Input size per partition ({input_size_per_partition}) must be divisible by " + f"group size ({self.quant_config.group_size})." + ) + + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + # By default, no sharding over "input dim" + scales_and_zp_size = input_size // group_size + scales_and_zp_input_dim = None + # Initialize or retrieve the BitBLAS matrix multiplication operator. + self._configure_bitblas_matmul( + input_size_per_partition, + output_size_per_partition, + params_dtype=params_dtype, + enable_tuning=self.ENABLE_TUNING, + bias=False, + layout="nt", + bits=self.quant_config.weight_bits, + ) + + # Init buffers + # Quantized weights + qweight = Parameter( + torch.empty( + input_size_per_partition // self.quant_config.pack_factor, + output_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + set_weight_attrs( + qweight, + { + **extra_weight_attrs, + "input_dim": 0, + "output_dim": 1, + "packed_dim": 0, + "pack_factor": self.quant_config.pack_factor, + }, + ) + + # Activation order + g_idx = Parameter( + torch.empty( + input_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + # Ignore warning from fused linear layers such as QKVParallelLinear. 
+ set_weight_attrs( + g_idx, + {**extra_weight_attrs, "input_dim": 0, "ignore_warning": True}, + ) + + g_idx_sort_indices = torch.empty( + g_idx.shape, + dtype=torch.int32, + ) + + # Scales + scales = Parameter( + torch.empty( + scales_and_zp_size, + output_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + scales, + { + **extra_weight_attrs, + "input_dim": scales_and_zp_input_dim, + "output_dim": 1, + }, + ) + + # Quantized zero-points + qzeros = Parameter( + torch.empty( + scales_and_zp_size, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + set_weight_attrs( + qzeros, + { + **extra_weight_attrs, + "input_dim": scales_and_zp_input_dim, + "output_dim": 1, + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + }, + ) + layer.register_parameter("qweight", qweight) + layer.register_parameter("g_idx", g_idx) + layer.register_parameter("scales", scales) + layer.register_parameter("qzeros", qzeros) + layer.g_idx_sort_indices = g_idx_sort_indices + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.input_size = input_size + layer.bitblas_state = GPTQBitBLASState.REPACK + + def _configure_bitblas_matmul( + self, + infeatures, + outfeatures, + params_dtype, + enable_tuning, + bias, + layout, + bits, + ): + from bitblas import MatmulConfig + + bitblas_dtype = self.BITBLAS_DTYPES[params_dtype] + + W_dtype = f"uint{bits}" + + matmul_config = MatmulConfig( + M=self.OPT_FEATURES, + N=outfeatures, + K=infeatures, + A_dtype=bitblas_dtype, + W_dtype=W_dtype, + out_dtype=bitblas_dtype, + accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype, + storage_dtype=self.quant_config.GPTQ_BITBLAS_STORAGE_DTYPE, + with_scaling=True, + with_zeros=True, + group_size=self.quant_config.group_size, + with_bias=bias, + layout=layout, + zeros_mode=self.quant_config.zeros_mode, + ) + self.bitblas_matmul = self._get_or_create_bitblas_operator( + matmul_config, enable_tuning + ) + + def _get_or_create_bitblas_operator(self, config, enable_tuning): + from bitblas import Matmul + from bitblas.cache import global_operator_cache + + if global_operator_cache.size() == 0: + global_operator_cache.load_from_database( + BITBLAS_DATABASE_PATH, BITBLAS_TARGET + ) + + bitblas_matmul = global_operator_cache.get(config) + if bitblas_matmul is None: + bitblas_matmul = Matmul(config, target=BITBLAS_TARGET) + if enable_tuning: + bitblas_matmul.hardware_aware_finetune(topk=20) + global_operator_cache.add(config, bitblas_matmul) + global_operator_cache.save_into_database( + BITBLAS_DATABASE_PATH, BITBLAS_TARGET + ) + print( + "BitBLAS Tuning done, appended operator to global_operator_cache." + ) + else: + print(f"BitBLAS Operator {config} created.") + else: + print(f"BitBLAS Operator {config} found in global_operator_cache.") + return bitblas_matmul + + def repack_bitblas_from_gptq(self, b_q_weight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor): + from bitblas.quantization.utils import general_compress + + # qweight in gptq old quant linear stored with (outfeatures, infeatures), should be transposed. + qweight = b_q_weight.T.contiguous().view(self.quant_config.TORCH_BITBLAS_STORAGE_DTYPE) + if self.bitblas_matmul.weight_transform is not None: + qweight = self.bitblas_matmul.weight_transform(qweight.cpu()).cuda() + # scales in gptq old quant linear stored with (infeatures // group_size, outfeatures), should be transposed. 
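        # Shape bookkeeping for the repack (illustrative numbers for a 4-bit
        # layer with infeatures=4096, outfeatures=4096, group_size=128):
        #   GPTQ checkpoint:  qweight (512, 4096) int32, scales (32, 4096),
        #                     qzeros (32, 512) int32
        #   after repack:     qweight takes bitblas_matmul.retrieve_weight_shape(),
        #                     scales becomes (4096, 32), and qzeros is rebuilt
        #                     according to zeros_mode below.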
+ scales = scales.T.contiguous() + # qzeros should be de-quantized to int zeros. + intzeros = unpack_qzeros(qzeros, self.quant_config.weight_bits).T.contiguous() + zeros = None + if self.bitblas_matmul.config.zeros_mode == "original": + zeros = intzeros.to(torch.float16).contiguous() + elif self.bitblas_matmul.config.zeros_mode == "rescale": + zeros[:, :] = intzeros.to(torch.float16)[:, :] * scales[:, :] + elif self.bitblas_matmul.config.zeros_mode == "quantized": + zeros = ( + torch.Tensor( + general_compress(intzeros.T.contiguous().cpu().numpy(), self.quant_config.weight_bits) + ) + .to(qweight.device) + .to(self.quant_config.TORCH_BITBLAS_STORAGE_DTYPE) + .contiguous() + ) + else: + raise ValueError( + f"Unsupported zeros type: {self.bitblas_matmul.config.zeros_mode}" + ) + + return qweight, scales, zeros + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + part_size_n = layer.output_size_per_partition + out_shape = x.shape[:-1] + (part_size_n,) + + if layer.bitblas_state == GPTQBitBLASState.REPACK: + layer.bitblas_state = GPTQBitBLASState.READY + + # Newly generated tensors need to replace existing tensors that are + # already registered as parameters by vLLM (and won't be freed) + def replace_tensor(name, new_t): + # It is important to use copy_() here since it ensures + # the same buffer is reused + getattr(layer, name).copy_(new_t.view(getattr(layer, name).dtype).view(getattr(layer, name).shape)) + del new_t + + # Repack weights + bitblas_qweight, bitblas_scales, bitblas_qzeros = self.repack_bitblas_from_gptq( + layer.qweight, + layer.scales, + layer.qzeros, + ) + replace_tensor("qweight", bitblas_qweight) + replace_tensor("scales", bitblas_scales) + replace_tensor("qzeros", bitblas_qzeros) + + output = self.bitblas_matmul(x, layer.qweight, layer.scales, layer.qzeros) + + if bias is not None: + output.add_(bias) # In-place add + return output.reshape(out_shape) From b92de929d78b5aa27b0918b5d0b086fe6010f967 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 1 Jul 2024 16:10:17 +0000 Subject: [PATCH 02/24] chore: Remove unused input_size and output_size variables in MarlinLinearMethod constructor --- .../layers/quantization/bitblas.py | 132 +++++++++------- .../layers/quantization/gptq_bitblas.py | 143 +++++++++--------- .../layers/quantization/marlin.py | 2 +- 3 files changed, 146 insertions(+), 131 deletions(-) diff --git a/vllm/model_executor/layers/quantization/bitblas.py b/vllm/model_executor/layers/quantization/bitblas.py index 17f4211f6efa..9cad817608aa 100644 --- a/vllm/model_executor/layers/quantization/bitblas.py +++ b/vllm/model_executor/layers/quantization/bitblas.py @@ -29,6 +29,7 @@ BITBLAS_SUPPORTED_NUM_BITS = [1, 2, 4, 8] BITBLAS_SUPPORTED_SYM = [False, True] + class BitBLASConfig(QuantizationConfig): """Config class for BitBLAS. @@ -37,7 +38,7 @@ class BitBLASConfig(QuantizationConfig): TORCH_DTYPE = torch.float16 STORAGE_DTYPE = "int8" # assume int8 storage TORCH_STORAGE_DTYPE = getattr(torch, STORAGE_DTYPE) - ZEROS_TYPE = "quantized" # "original" or "rescale" or "quantized", the gptq_bitblas prefer "quantized" + ZEROS_MODE = "quantized" # "original" or "rescale" or "quantized", the gptq_bitblas prefer "quantized" def __init__(self, weight_bits: int, group_size: int, desc_act: bool, is_sym: bool) -> None: @@ -73,7 +74,7 @@ def __init__(self, weight_bits: int, group_size: int, desc_act: bool, self.nbits = weight_bits # Zeros type for the quantized weights. 
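        # Roughly, the three modes describe how zero-points are handed to the
        # kernel: "original" keeps integer zero-points (cast to fp16),
        # "rescale" stores them pre-multiplied by the scales, and "quantized"
        # keeps them bit-packed in the int8 storage dtype (the form the GPTQ
        # repack path produces).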
- self.zeros_type = self.ZEROS_TYPE + self.zeros_mode = self.ZEROS_MODE def __repr__(self) -> str: return (f"BitBLASConfig(weight_bits={self.weight_bits}, " @@ -99,8 +100,11 @@ def get_config_filenames(cls) -> List[str]: @classmethod def from_config(cls, config: Dict[str, Any]) -> "BitBLASConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) group_size = cls.get_from_keys(config, ["group_size"]) - return cls(group_size) + desc_act = cls.get_from_keys(config, ["desc_act"]) + is_sym = cls.get_from_keys(config, ["sym"]) + return cls(weight_bits, group_size, desc_act, is_sym) @classmethod def override_quantization_method(cls, hf_quant_cfg, @@ -108,7 +112,7 @@ def override_quantization_method(cls, hf_quant_cfg, # compat: autogptq >=0.8.0 use checkpoint_format: str # compat: autogptq <=0.7.1 is_bitblas_format: bool is_bitblas_format = (hf_quant_cfg.get("checkpoint_format") == "bitblas" - or hf_quant_cfg.get("is_bitblas_format", False)) + or hf_quant_cfg.get("is_bitblas_format", False)) is_valid_user_quant = (user_quant is None or user_quant == "gptq" or user_quant == "bitblas") @@ -139,19 +143,26 @@ class BitBLASLinearMethod(LinearMethodBase): """ OPT_FEATURES = [1, 16, 32, 64, 128, 256, 512] ENABLE_TUNING = True + BITBLAS_DTYPES = { + torch.float32: "float32", + torch.float16: "float16", + torch.half: "float16", + torch.int8: "int8", + } def __init__(self, quant_config: BitBLASConfig): self.quant_config = quant_config def create_weights( self, + layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: int, + output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs, - ) -> Dict[str, Any]: + ): """Creates quantized weights for use in linear operations. The function initializes and returns a dictionary containing quantized weights, scales, and zeros @@ -179,19 +190,17 @@ def create_weights( # Validate output_size_per_partition output_size_per_partition = sum(output_partition_sizes) - if ( - self.quant_config.group_size != -1 - and input_size_per_partition % self.quant_config.group_size != 0 - ): + if (self.quant_config.group_size != -1 and + input_size_per_partition % self.quant_config.group_size != 0): raise ValueError( f"Input size per partition ({input_size_per_partition}) must be divisible by " - f"group size ({self.quant_config.group_size})." - ) + f"group size ({self.quant_config.group_size}).") # Initialize or retrieve the BitBLAS matrix multiplication operator. self._configure_bitblas_matmul( input_size_per_partition, output_size_per_partition, + params_dtype=params_dtype, enable_tuning=self.ENABLE_TUNING, bias=False, layout="nt", @@ -211,25 +220,26 @@ def create_weights( set_weight_attrs( qweight, { - "input_dim": 1, - "output_dim": 0, - "packed_dim": 1, - "bitblas_tile_size": ( - self.bitblas_matmul.retrieve_weight_shape()[-2] - if self.quant_config.weight_propagation - else None - ), - "pack_factor": self.quant_config.pack_factor, - "weight_propagation": self.quant_config.weight_propagation, + "input_dim": + 1, + "output_dim": + 0, + "packed_dim": + 1, + "bitblas_tile_size": + (self.bitblas_matmul.retrieve_weight_shape()[-2] if + self.bitblas_matmul.transform_weight is not None else None), + "pack_factor": + self.quant_config.pack_factor, + "weight_propagation": + self.bitblas_matmul.transform_weight is not None, }, ) # Compute the number of input groups for channel-wise quantization. 
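        # group_size == -1 collapses this to a single group (one scale per
        # output channel); e.g. a 4096-wide input with group_size 128 yields
        # 32 groups.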
- input_groups = ( - 1 - if self.quant_config.group_size == -1 - else input_size_per_partition // self.quant_config.group_size - ) + input_groups = (1 if self.quant_config.group_size == -1 else + input_size_per_partition // + self.quant_config.group_size) # Initialize scales and zeros for the quantized weights. scales = Parameter( @@ -241,8 +251,11 @@ def create_weights( ), requires_grad=False, ) - set_weight_attrs(scales, {"input_dim": None if input_groups == 1 else 1, "output_dim": 0}) - if self.quant_config.zeros_type == "quantized": + set_weight_attrs(scales, { + "input_dim": None if input_groups == 1 else 1, + "output_dim": 0 + }) + if self.quant_config.zeros_mode == "quantized": zeros = Parameter( torch.empty( input_groups, @@ -265,20 +278,30 @@ def create_weights( ) else: zeros = Parameter( - torch.empty(output_size_per_partition, input_groups, + torch.empty(output_size_per_partition, + input_groups, device="cuda", dtype=params_dtype), requires_grad=False, ) # Set attributes to indicate how scales and zeros are applied. - set_weight_attrs(scales, {"input_dim": None if input_groups == 1 else 1, "output_dim": 0}) + set_weight_attrs(scales, { + "input_dim": None if input_groups == 1 else 1, + "output_dim": 0 + }) + + layer.register_parameter("qweight", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("scales", scales) + set_weight_attrs(scales, extra_weight_attrs) + layer.register_parameter("zeros", zeros) + set_weight_attrs(zeros, extra_weight_attrs) - return {"qweight": qweight, "scales": scales, "zeros": zeros} - def _configure_bitblas_matmul( self, infeatures, outfeatures, + params_dtype, enable_tuning, bias, layout, @@ -286,7 +309,7 @@ def _configure_bitblas_matmul( ): from bitblas import MatmulConfig - bitblas_dtype = self.BITBLAS_DTYPES[self.TORCH_DTYPE] + bitblas_dtype = self.BITBLAS_DTYPES[params_dtype] W_dtype = f"uint{bits}" @@ -298,63 +321,56 @@ def _configure_bitblas_matmul( W_dtype=W_dtype, out_dtype=bitblas_dtype, accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype, - storage_dtype=self.STORAGE_DTYPE, + storage_dtype=self.quant_config.STORAGE_DTYPE, with_scaling=True, with_zeros=True, - group_size=self.group_size, + group_size=self.quant_config.group_size, with_bias=bias, layout=layout, - zeros_mode=self.zeros_mode, + zeros_mode=self.quant_config.zeros_mode, ) self.bitblas_matmul = self._get_or_create_bitblas_operator( - matmul_config, enable_tuning - ) + matmul_config, enable_tuning) def _get_or_create_bitblas_operator(self, config, enable_tuning): from bitblas import Matmul from bitblas.cache import global_operator_cache if global_operator_cache.size() == 0: - global_operator_cache.load_from_database( - BITBLAS_DATABASE_PATH, BITBLAS_TARGET - ) + global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, + BITBLAS_TARGET) bitblas_matmul = global_operator_cache.get(config) if bitblas_matmul is None: - bitblas_matmul = Matmul(config, target=self.target) + bitblas_matmul = Matmul(config, target=BITBLAS_TARGET) if enable_tuning: bitblas_matmul.hardware_aware_finetune(topk=20) global_operator_cache.add(config, bitblas_matmul) global_operator_cache.save_into_database( - BITBLAS_DATABASE_PATH, BITBLAS_TARGET - ) - print( + BITBLAS_DATABASE_PATH, BITBLAS_TARGET) + logger.info( "BitBLAS Tuning done, appended operator to global_operator_cache." 
) else: - print("BitBLAS Operator created.") + logger.info(f"BitBLAS Operator {config} created.") else: - print("BitBLAS Operator found in global_operator_cache.") + logger.info( + f"BitBLAS Operator {config} found in global_operator_cache.") return bitblas_matmul - + def apply( self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - qweight = layer.B - scales = layer.s - workspace = layer.workspace + qweight = layer.qweight + scales = layer.scales + qzeros = layer.zeros x_2d = x.view(-1, x.shape[-1]) - size_m = x_2d.shape[0] - size_k = x_2d.shape[1] - size_n = scales.shape[1] - - output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m, - size_n, size_k) + output_2d = self.bitblas_matmul(x_2d, qweight, scales, qzeros) output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) diff --git a/vllm/model_executor/layers/quantization/gptq_bitblas.py b/vllm/model_executor/layers/quantization/gptq_bitblas.py index 780154d7047c..4eaf2d69692b 100644 --- a/vllm/model_executor/layers/quantization/gptq_bitblas.py +++ b/vllm/model_executor/layers/quantization/gptq_bitblas.py @@ -17,7 +17,6 @@ logger = init_logger(__name__) - try: import bitblas except ImportError as e: @@ -37,7 +36,7 @@ GPTQ_BITBLAS_SUPPORTED_SYM = [False, True] -def unpack_qzeros(qzeros, bits): +def unpack_qzeros(qzeros, bits) -> torch.Tensor: qzeros = qzeros.view(torch.int32) elems_per_int32 = 32 // bits unpacked_zeros = torch.zeros( @@ -49,24 +48,25 @@ def unpack_qzeros(qzeros, bits): for col in range(unpacked_zeros.shape[1]): i = col % elems_per_int32 - unpacked_zeros[:, col] = (qzeros[:, col // elems_per_int32] >> (bits * i)) & 0xF + unpacked_zeros[:, col] = (qzeros[:, col // elems_per_int32] >> + (bits * i)) & 0xF return unpacked_zeros + 1 - class GPTQBitBLASConfig(QuantizationConfig): """Config class for GPTQ BitBLAS""" TORCH_DTYPE = torch.float16 - GPTQ_CKPT_STORAGE_DTYPE = "int32" # GPTQ Default Checkpoints use int32 as storage dtype - GPTQ_BITBLAS_STORAGE_DTYPE = "int8" # BitBLAS uses int8 as storage dtype + GPTQ_CKPT_STORAGE_DTYPE = ( + "int32" # GPTQ Default Checkpoints use int32 as storage dtype + ) + GPTQ_BITBLAS_STORAGE_DTYPE = "int8" # BitBLAS uses int8 as storage dtype TORCH_BITBLAS_STORAGE_DTYPE = getattr(torch, GPTQ_BITBLAS_STORAGE_DTYPE) ZEROS_MODE = "quantized" # "original" or "rescale" or "quantized", the gptq_bitblas prefer "quantized" - def __init__( - self, weight_bits: int, group_size: int, desc_act: bool, is_sym: bool - ) -> None: + def __init__(self, weight_bits: int, group_size: int, desc_act: bool, + is_sym: bool) -> None: if desc_act and group_size == -1: # In this case, act_order == True is the same as act_order == False # (since we have only one group per output channel) @@ -82,19 +82,17 @@ def __init__( raise ValueError( f"BitBLAS does not support weight_bits = {self.weight_bits}. " f"Only weight_bits = {GPTQ_BITBLAS_SUPPORTED_NUM_BITS} " - "are supported." - ) + "are supported.") if self.is_sym not in GPTQ_BITBLAS_SUPPORTED_SYM: raise ValueError( f"BitBLAS does not support is_sym = {self.is_sym}. " - f"Only sym = {GPTQ_BITBLAS_SUPPORTED_SYM} are supported." - ) + f"Only sym = {GPTQ_BITBLAS_SUPPORTED_SYM} are supported.") - self.storage_dtype = self.GPTQ_BITBLAS_STORAGE_DTYPE - storage_nbit = int("".join(c for c in self.GPTQ_CKPT_STORAGE_DTYPE if c.isdigit())) + storage_nbit = int("".join(c for c in self.GPTQ_CKPT_STORAGE_DTYPE + if c.isdigit())) # 4 Bits packed into 32 bit datatype. 
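        # e.g. int32 checkpoint storage with 4-bit weights:
        #   storage_nbit = 32, weight_bits = 4  ->  pack_factor = 8
        # (eight 4-bit values per int32 word; 8-bit weights give pack_factor 4).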
self.pack_factor = storage_nbit // weight_bits @@ -104,11 +102,9 @@ def __init__( self.zeros_mode = self.ZEROS_MODE def __repr__(self) -> str: - return ( - f"GPTQBitBLASConfig(weight_bits={self.weight_bits}, " - f"group_size={self.group_size}, " - f"desc_act={self.desc_act})" - ) + return (f"GPTQBitBLASConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act})") @classmethod def get_name(cls) -> str: @@ -135,31 +131,28 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQBitBLASConfig": return cls(weight_bits, group_size, desc_act, is_sym) @classmethod - def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: can_convert = cls.is_bitblas_compatible(hf_quant_cfg) is_valid_user_quant = user_quant is None or user_quant == "bitblas" if can_convert and is_valid_user_quant: - msg = ( - "The model is convertible to {} during runtime." - " Using {} kernel.".format(cls.get_name(), cls.get_name()) - ) + msg = ("The model is convertible to {} during runtime." + " Using {} kernel.".format(cls.get_name(), cls.get_name())) logger.info(msg) return cls.get_name() if can_convert and user_quant == "gptq": - logger.info( - "Detected that the model can run with gptq_bitblas" - ", however you specified quantization=gptq explicitly," - " so forcing gptq. Use quantization=gptq_bitblas for" - " faster inference" - ) + logger.info("Detected that the model can run with gptq_bitblas" + ", however you specified quantization=gptq explicitly," + " so forcing gptq. Use quantization=gptq_bitblas for" + " faster inference") return None def get_quant_method( - self, layer: torch.nn.Module - ) -> Optional["GPTQBitBLASLinearMethod"]: + self, + layer: torch.nn.Module) -> Optional["GPTQBitBLASLinearMethod"]: if isinstance(layer, LinearBase): return GPTQBitBLASLinearMethod(self) return None @@ -186,10 +179,8 @@ def is_bitblas_compatible(cls, quant_config: Dict[str, Any]): return False # Otherwise, can convert if model satisfies bitblas constraints. - return ( - num_bits in GPTQ_BITBLAS_SUPPORTED_NUM_BITS - and sym in GPTQ_BITBLAS_SUPPORTED_SYM - ) + return (num_bits in GPTQ_BITBLAS_SUPPORTED_NUM_BITS + and sym in GPTQ_BITBLAS_SUPPORTED_SYM) class GPTQBitBLASState(Enum): @@ -260,8 +251,7 @@ def create_weights( if input_size_per_partition % group_size != 0: raise ValueError( f"Input size per partition ({input_size_per_partition}) must be divisible by " - f"group size ({self.quant_config.group_size})." - ) + f"group size ({self.quant_config.group_size}).") # Validate output_size_per_partition output_size_per_partition = sum(output_partition_sizes) @@ -311,7 +301,10 @@ def create_weights( # Ignore warning from fused linear layers such as QKVParallelLinear. 
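        # g_idx maps every input channel to the quantization group whose
        # scale/zero-point it uses; for a checkpoint without act-order it is
        # effectively g_idx[i] == i // group_size, and the g_idx_sort_indices
        # buffer below is reserved for the permutation that sorts channels
        # back into ascending group order when desc_act reshuffled them.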
set_weight_attrs( g_idx, - {**extra_weight_attrs, "input_dim": 0, "ignore_warning": True}, + { + **extra_weight_attrs, "input_dim": 0, + "ignore_warning": True + }, ) g_idx_sort_indices = torch.empty( @@ -399,17 +392,15 @@ def _configure_bitblas_matmul( zeros_mode=self.quant_config.zeros_mode, ) self.bitblas_matmul = self._get_or_create_bitblas_operator( - matmul_config, enable_tuning - ) + matmul_config, enable_tuning) def _get_or_create_bitblas_operator(self, config, enable_tuning): from bitblas import Matmul from bitblas.cache import global_operator_cache if global_operator_cache.size() == 0: - global_operator_cache.load_from_database( - BITBLAS_DATABASE_PATH, BITBLAS_TARGET - ) + global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, + BITBLAS_TARGET) bitblas_matmul = global_operator_cache.get(config) if bitblas_matmul is None: @@ -418,42 +409,46 @@ def _get_or_create_bitblas_operator(self, config, enable_tuning): bitblas_matmul.hardware_aware_finetune(topk=20) global_operator_cache.add(config, bitblas_matmul) global_operator_cache.save_into_database( - BITBLAS_DATABASE_PATH, BITBLAS_TARGET - ) - print( + BITBLAS_DATABASE_PATH, BITBLAS_TARGET) + logger.info( "BitBLAS Tuning done, appended operator to global_operator_cache." ) else: - print(f"BitBLAS Operator {config} created.") + logger.info(f"BitBLAS Operator {config} created.") else: - print(f"BitBLAS Operator {config} found in global_operator_cache.") + logger.info( + f"BitBLAS Operator {config} found in global_operator_cache.") return bitblas_matmul - - def repack_bitblas_from_gptq(self, b_q_weight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor): + + def repack_bitblas_from_gptq(self, b_q_weight: torch.Tensor, + scales: torch.Tensor, qzeros: torch.Tensor): from bitblas.quantization.utils import general_compress # qweight in gptq old quant linear stored with (outfeatures, infeatures), should be transposed. - qweight = b_q_weight.T.contiguous().view(self.quant_config.TORCH_BITBLAS_STORAGE_DTYPE) + qweight = b_q_weight.T.contiguous().view( + self.quant_config.TORCH_BITBLAS_STORAGE_DTYPE) if self.bitblas_matmul.weight_transform is not None: - qweight = self.bitblas_matmul.weight_transform(qweight.cpu()).cuda() + qweight = self.bitblas_matmul.weight_transform( + qweight.cpu()).cuda() # scales in gptq old quant linear stored with (infeatures // group_size, outfeatures), should be transposed. scales = scales.T.contiguous() # qzeros should be de-quantized to int zeros. 
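        # unpack_qzeros splits each int32 word into 32 // bits zero-points and
        # adds 1 (GPTQ checkpoints store zero - 1); note the 0xF mask assumes
        # 4-bit packing. Toy check with made-up data:
        #   unpack_qzeros(torch.full((1, 1), 0x77777777, dtype=torch.int32), 4)
        #   -> a (1, 8) tensor of 8s.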
- intzeros = unpack_qzeros(qzeros, self.quant_config.weight_bits).T.contiguous() - zeros = None + intzeros = unpack_qzeros(qzeros, + self.quant_config.weight_bits).T.contiguous() + zeros: Optional[torch.Tensor] = None if self.bitblas_matmul.config.zeros_mode == "original": zeros = intzeros.to(torch.float16).contiguous() elif self.bitblas_matmul.config.zeros_mode == "rescale": + assert zeros is not None, "zeros should not be None" zeros[:, :] = intzeros.to(torch.float16)[:, :] * scales[:, :] elif self.bitblas_matmul.config.zeros_mode == "quantized": - zeros = ( - torch.Tensor( - general_compress(intzeros.T.contiguous().cpu().numpy(), self.quant_config.weight_bits) - ) - .to(qweight.device) - .to(self.quant_config.TORCH_BITBLAS_STORAGE_DTYPE) - .contiguous() - ) + zeros = (torch.Tensor( + general_compress( + intzeros.T.contiguous().cpu().numpy(), + self.quant_config.weight_bits, + )).to(qweight.device).to( + self.quant_config.TORCH_BITBLAS_STORAGE_DTYPE).contiguous( + )) else: raise ValueError( f"Unsupported zeros type: {self.bitblas_matmul.config.zeros_mode}" @@ -469,7 +464,7 @@ def apply( ) -> torch.Tensor: part_size_n = layer.output_size_per_partition - out_shape = x.shape[:-1] + (part_size_n,) + out_shape = x.shape[:-1] + (part_size_n, ) if layer.bitblas_state == GPTQBitBLASState.REPACK: layer.bitblas_state = GPTQBitBLASState.READY @@ -479,20 +474,24 @@ def apply( def replace_tensor(name, new_t): # It is important to use copy_() here since it ensures # the same buffer is reused - getattr(layer, name).copy_(new_t.view(getattr(layer, name).dtype).view(getattr(layer, name).shape)) + getattr(layer, name).copy_( + new_t.view(getattr(layer, name).dtype).view( + getattr(layer, name).shape)) del new_t # Repack weights - bitblas_qweight, bitblas_scales, bitblas_qzeros = self.repack_bitblas_from_gptq( - layer.qweight, - layer.scales, - layer.qzeros, - ) + bitblas_qweight, bitblas_scales, bitblas_qzeros = ( + self.repack_bitblas_from_gptq( + layer.qweight, + layer.scales, + layer.qzeros, + )) replace_tensor("qweight", bitblas_qweight) replace_tensor("scales", bitblas_scales) replace_tensor("qzeros", bitblas_qzeros) - output = self.bitblas_matmul(x, layer.qweight, layer.scales, layer.qzeros) + output = self.bitblas_matmul(x, layer.qweight, layer.scales, + layer.qzeros) if bias is not None: output.add_(bias) # In-place add diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 3613c9d9ecf2..ea1e73ebbd95 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -124,7 +124,7 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): - del output_size # Unused. + del input_size, output_size # Unused. if params_dtype != torch.float16: raise ValueError( From 71ea469b487c577f99bf36e07365c1e5f42472e3 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 16 Jul 2024 03:19:16 +0000 Subject: [PATCH 03/24] Support BitNet Model for 1.58bits. 
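
BitNet b1.58 keeps fp16 master weights and quantizes on the fly: each linear
layer's weights are scaled by the inverse of their mean absolute value and
rounded into {-1, 0, +1}, while activations are quantized per token to int8.
The layers therefore carry only an fp16 `weight` plus an int8 `qweight`
buffer (no per-group scales or zero-points); BitBLAS runs the int8 GEMM and
the fp32 result is rescaled by the two scale factors afterwards. A rough
standalone sketch of the quantizers (mirrors weight_quant/activation_quant in
bitnet_bitblas.py; illustrative only):

    import torch

    def weight_quant(w: torch.Tensor) -> torch.Tensor:
        s = 1 / w.abs().mean().clamp(min=1e-5)        # per-tensor scale
        return (w.float() * s).round().clamp(-1, 1).to(torch.int8)

    def activation_quant(x: torch.Tensor, bits: int = 8) -> torch.Tensor:
        qp = 2 ** (bits - 1) - 1                      # 127 for int8
        s = qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
        return (x.float() * s).round().clamp(-qp - 1, qp).to(torch.int8)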
--- .gitignore | 2 +- vllm/config.py | 23 +- .../layers/quantization/__init__.py | 3 + .../layers/quantization/bitnet_bitblas.py | 387 +++++++++ .../layers/quantization/gptq_bitblas.py | 3 +- vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/bitnet.py | 783 ++++++++++++++++++ vllm/transformers_utils/tokenizer.py | 16 +- .../transformers_utils/tokenizers/__init__.py | 2 + vllm/transformers_utils/tokenizers/bitnet.py | 504 +++++++++++ 10 files changed, 1717 insertions(+), 7 deletions(-) create mode 100644 vllm/model_executor/layers/quantization/bitnet_bitblas.py create mode 100644 vllm/model_executor/models/bitnet.py create mode 100644 vllm/transformers_utils/tokenizers/bitnet.py diff --git a/.gitignore b/.gitignore index 47289386c0e5..97a4a3774f62 100644 --- a/.gitignore +++ b/.gitignore @@ -189,4 +189,4 @@ hip_compat.h *.json # Debug files -./debug +debug/ diff --git a/vllm/config.py b/vllm/config.py index d42762448964..dce48eeda56b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -297,6 +297,21 @@ def get_vocab_size(self) -> int: def get_hidden_size(self) -> int: return self.hf_text_config.hidden_size + def find_flash_attn_supported_head_dims(self, head_dim: int) -> int: + """ + Find the closest head dimension to the given head dimension that is supported by Flash Attention. + """ + from vllm.attention.backends.flash_attn import FlashAttentionBackend + FLASHATTN_SUPPORTED_HEAD_DIMS = FlashAttentionBackend.get_supported_head_sizes() + + for supported_head_dim in FLASHATTN_SUPPORTED_HEAD_DIMS: + if head_dim <= supported_head_dim: + return supported_head_dim + raise ValueError( + f"Head dimension {head_dim} is not supported by Flash Attention. Supported head dimensions are " + f"{FLASHATTN_SUPPORTED_HEAD_DIMS}." + ) + def get_head_size(self) -> int: # TODO remove hard code if hasattr(self.hf_text_config, "model_type" @@ -304,6 +319,11 @@ def get_head_size(self) -> int: # FlashAttention supports only head_size 32, 64, 128, 256, # we need to pad head_size 192 to 256 return 256 + if hasattr(self.hf_text_config, "architectures" + ) and 'BitnetForCausalLM' in self.hf_text_config.architectures: + return self.find_flash_attn_supported_head_dims((self.hf_text_config.hidden_size // + self.hf_text_config.num_attention_heads)) + if hasattr(self.hf_text_config, "head_dim"): return self.hf_text_config.head_dim # FIXME(woosuk): This may not be true for all models. @@ -742,7 +762,8 @@ def __init__(self, else: # If max_model_len is too short, use 2048 as the default value # for higher throughput. 
- self.max_num_batched_tokens = max(max_model_len, 2048) + # self.max_num_batched_tokens = max(max_model_len + self.max_num_batched_tokens = 1024 if enable_chunked_prefill: logger.info("Chunked prefill is enabled (EXPERIMENTAL).") diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 7eaf9c4832e0..4585ae98bb0b 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -6,6 +6,8 @@ QuantizationConfig) from vllm.model_executor.layers.quantization.bitsandbytes import ( BitsAndBytesConfig) +from vllm.model_executor.layers.quantization.bitnet_bitblas import ( + BITNETBitBLASConfig) from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 CompressedTensorsConfig) from vllm.model_executor.layers.quantization.deepspeedfp import ( @@ -28,6 +30,7 @@ "fp8": Fp8Config, # The order of gptq methods is important for config.py iteration over # override_quantization_method(..) + "bitnet_bitblas": BITNETBitBLASConfig, "marlin": MarlinConfig, "bitblas": BitBLASConfig, "gptq_marlin_24": GPTQMarlin24Config, diff --git a/vllm/model_executor/layers/quantization/bitnet_bitblas.py b/vllm/model_executor/layers/quantization/bitnet_bitblas.py new file mode 100644 index 000000000000..0b3f7ca14526 --- /dev/null +++ b/vllm/model_executor/layers/quantization/bitnet_bitblas.py @@ -0,0 +1,387 @@ +import enum +from enum import Enum +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +import bitblas.cache +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + set_weight_attrs, +) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig + +logger = init_logger(__name__) + +try: + import bitblas +except ImportError as e: + bitblas_import_exception = e + raise ValueError( + f"Trying to use the bitblas backend, but could not import dependencies with the following error: {bitblas_import_exception}" + ) + +import bitblas +from bitblas.utils import auto_detect_nvidia_target + +bitblas.set_log_level("Debug") +BITBLAS_TARGET = auto_detect_nvidia_target() +BITBLAS_DATABASE_PATH = bitblas.cache.get_database_path() + +BITNET_BITBLAS_SUPPORTED_NUM_BITS = [1, 2, 4, 8] + + +class BITNETBitBLASConfig(QuantizationConfig): + """Config class for BITNET BitBLAS""" + + TORCH_DTYPE = torch.int8 + BITNET_CKPT_STORAGE_DTYPE = ( + "float16" # BITNET Default Checkpoints use float16 as storage dtype + ) + BITNET_BITBLAS_STORAGE_DTYPE = "int8" # BitBLAS uses int8 as storage dtype + TORCH_BITBLAS_STORAGE_DTYPE = getattr(torch, BITNET_BITBLAS_STORAGE_DTYPE) + + def __init__(self, weight_bits: int, is_sym: bool) -> None: + self.input_bits = 8 + self.weight_bits = weight_bits + self.is_sym = is_sym + + # Verify + if self.weight_bits not in BITNET_BITBLAS_SUPPORTED_NUM_BITS: + raise ValueError( + f"BitBLAS does not support weight_bits = {self.weight_bits}. 
" + f"Only weight_bits = {BITNET_BITBLAS_SUPPORTED_NUM_BITS} " + "are supported.") + + self.storage_dtype = self.BITNET_BITBLAS_STORAGE_DTYPE + self.nbits = weight_bits + + def __repr__(self) -> str: + return (f"BITNETBitBLASConfig(weight_bits={self.weight_bits}, is_sym={self.is_sym})") + + @classmethod + def get_name(cls) -> str: + return "bitnet_bitblas" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.int8] + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "BITNETBitBLASConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + is_sym = cls.get_from_keys(config, ["sym"]) + return cls(weight_bits, is_sym) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + can_convert = cls.is_bitblas_compatible(hf_quant_cfg) + + is_valid_user_quant = user_quant is None or user_quant == "bitblas" + + if can_convert and is_valid_user_quant: + msg = ("The model is convertible to {} during runtime." + " Using {} kernel.".format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + if can_convert and user_quant == "bitnet": + logger.info("Detected that the model can run with bitnet_bitblas" + ", however you specified quantization=bitnet explicitly," + " so forcing bitnet. Use quantization=bitnet_bitblas for" + " faster inference") + return None + + def get_quant_method( + self, + layer: torch.nn.Module) -> Optional["BITNETBitBLASLinearMethod"]: + if isinstance(layer, LinearBase): + return BITNETBitBLASLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + @classmethod + def is_bitblas_compatible(cls, quant_config: Dict[str, Any]): + # Extract data from quant config. + num_bits = quant_config.get("bits", None) + sym = quant_config.get("sym", None) + + # If we cannot find the info needed in the config, cannot convert. + if num_bits is None or sym is None: + return False + + # If the capability of the device is too low, cannot convert. + major, minor = torch.cuda.get_device_capability() + device_capability = major * 10 + minor + if device_capability < cls.get_min_capability(): + return False + + # Otherwise, can convert if model satisfies bitblas constraints. + return num_bits in BITNET_BITBLAS_SUPPORTED_NUM_BITS + + +class BITNETBitBLASState(Enum): + REPACK = enum.auto() + READY = enum.auto() + + +class BITNETBitBLASLinearMethod(LinearMethodBase): + """Linear method for BITNET BitBLAS. + + Args: + quant_config: The BITNET BitBLAS quantization config. + """ + + ENABLE_TUNING = True + BITBLAS_DTYPES = { + torch.float32: "float32", + torch.float16: "float16", + torch.half: "float16", + torch.int8: "int8", + } + + def __init__(self, quant_config: BITNETBitBLASConfig) -> None: + self.quant_config = quant_config + self.Qp = 2**(quant_config.input_bits - 1) - 1 + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + """Creates quantized weights for use in linear operations. + + The function initializes and returns a dictionary containing quantized weights, scales, and zeros + for performing quantized matrix multiplication operations. 
+ + Args: + input_size_per_partition: The size of the input partition. + output_partition_sizes: The size of the output partition. + input_size: The total size of the input (unused). + output_size: The total size of the output (unused). + params_dtype: The data type of the parameters (expected to be torch.float16). + + Returns: + A dictionary containing the quantized weights ('qweight'), scales ('scales'), and zeros ('zeros'). + + Raises: + ValueError: If `params_dtype` is not `torch.float16` or if the input size per partition + is not divisible by the group size in `quant_config`. + """ + del output_size # Unused arguments. + if params_dtype != torch.float16: + raise ValueError( + f"Parameter data type must be torch.float16, but got {params_dtype}" + ) + + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + bitblas_dtype = "int8" + # Initialize or retrieve the BitBLAS matrix multiplication operator. + self._configure_bitblas_matmul( + input_size_per_partition, + output_size_per_partition, + bitblas_dtype=bitblas_dtype, + enable_tuning=self.ENABLE_TUNING, + bias=False, + layout="nt", + bits=self.quant_config.weight_bits, + ) + + # Init buffers + # Quantized weights + weight = Parameter( + torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float16, + ), + requires_grad=False, + ) + set_weight_attrs( + weight, + { + **extra_weight_attrs, + "input_dim": 1, + "output_dim": 0, + }, + ) + + qweight = Parameter( + torch.empty( + *self.bitblas_matmul.retrieve_weight_shape(), + dtype=torch.int8, + ), + requires_grad=False, + ) + set_weight_attrs( + qweight, + { + **extra_weight_attrs, + "input_dim": 1, + "output_dim": 0, + }, + ) + + layer.register_parameter("weight", weight) + layer.register_parameter("qweight", qweight) + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.input_size = input_size + layer.bitblas_state = BITNETBitBLASState.REPACK + + def _configure_bitblas_matmul( + self, + infeatures, + outfeatures, + bitblas_dtype, + enable_tuning, + bias, + layout, + bits, + ): + from bitblas import MatmulConfig + + W_dtype = f"int{bits}" + + matmul_config = MatmulConfig( + N=outfeatures, + K=infeatures, + A_dtype=bitblas_dtype, + W_dtype=W_dtype, + out_dtype="float32", + accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype, + storage_dtype=self.quant_config.BITNET_BITBLAS_STORAGE_DTYPE, + with_scaling=False, + with_zeros=False, + with_bias=bias, + layout=layout, + ) + self.bitblas_matmul = self._get_or_create_bitblas_operator( + matmul_config, enable_tuning) + + def _get_or_create_bitblas_operator(self, config, enable_tuning): + from bitblas import Matmul + from bitblas.cache import global_operator_cache + + if global_operator_cache.size() == 0: + global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, + BITBLAS_TARGET) + + bitblas_matmul = global_operator_cache.get(config) + if bitblas_matmul is None: + bitblas_matmul = Matmul(config, target=BITBLAS_TARGET, enable_tuning=False) + if enable_tuning: + bitblas_matmul.hardware_aware_finetune(topk=20) + global_operator_cache.add(config, bitblas_matmul) + global_operator_cache.save_into_database( + BITBLAS_DATABASE_PATH, BITBLAS_TARGET) + logger.info( + "BitBLAS Tuning done, appended operator to global_operator_cache." 
+ ) + else: + logger.info(f"BitBLAS Operator {config} created.") + else: + logger.info( + f"BitBLAS Operator {config} found in global_operator_cache.") + return bitblas_matmul + + def weight_quant(self, weight): + weight = weight.float() + s = 1 / weight.abs().mean().clamp(min=1e-5) + result = (weight * s).round().clamp(-1, 1) + return result.type(torch.int8) + + def activation_quant(self, x, num_bits=8): + x = x.float() + Qn = -(2**(num_bits - 1)) + Qp = 2**(num_bits - 1) - 1 + s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) + result = (x * s).round().clamp(Qn, Qp) + return result.type(torch.int8) + + def repack_bitblas_from_bitnet(self, b_q_weight: torch.Tensor, is_qkv_packed: bool=False): + if is_qkv_packed: + hidden_size = b_q_weight.size(0) + sw_q = 1 / b_q_weight[:hidden_size // 3].abs().mean().clamp(min=1e-5) + sw_k = 1 / b_q_weight[hidden_size // 3:2 * hidden_size // 3].abs().mean().clamp(min=1e-5) + sw_v = 1 / b_q_weight[2 * hidden_size // 3:].abs().mean().clamp(min=1e-5) + self.sw_q = sw_q + self.sw_k = sw_k + self.sw_v = sw_v + qweight_q = self.weight_quant(b_q_weight[:hidden_size // 3]).detach() + qweight_k = self.weight_quant(b_q_weight[hidden_size // 3:2 * hidden_size // 3]).detach() + qweight_v = self.weight_quant(b_q_weight[2 * hidden_size // 3:]).detach() + qweight = torch.cat([qweight_q, qweight_k, qweight_v], dim=0) + else: + sw = 1 / b_q_weight.abs().mean().clamp(min=1e-5) + self.sw = sw + qweight = self.weight_quant(b_q_weight).detach() + qweight = self.bitblas_matmul.transform_weight(qweight) + return qweight + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + part_size_n = layer.output_size_per_partition + out_shape = x.shape[:-1] + (part_size_n, ) + quant_input = self.activation_quant(x, self.quant_config.input_bits).detach() + + if layer.bitblas_state == BITNETBitBLASState.REPACK: + layer.bitblas_state = BITNETBitBLASState.READY + + # Newly generated tensors need to replace existing tensors that are + # already registered as parameters by vLLM (and won't be freed) + def free_tensor(name): + # free the original weight tensor + delattr(layer, name) + def replace_tensor(name, new_t): + # Cannot use copy_() because the storage shape and dtype are different + # del layer._parameters[name] + delattr(layer, name) + setattr(layer, name, new_t) + + # Repack weights + bitblas_qweight = self.repack_bitblas_from_bitnet( + layer.weight, + ) + # free the original weight tensor + free_tensor("weight") + replace_tensor("qweight", bitblas_qweight) + + fp32_out = self.bitblas_matmul(quant_input, layer.qweight) + sw = self.sw + Qp = self.Qp + si = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) + # if / (si * sw) it will inf in some cases + output = fp32_out / si + output = output / sw + output = output.half() + output = output.type(x.dtype) + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/gptq_bitblas.py b/vllm/model_executor/layers/quantization/gptq_bitblas.py index 4eaf2d69692b..75f9b087f07a 100644 --- a/vllm/model_executor/layers/quantization/gptq_bitblas.py +++ b/vllm/model_executor/layers/quantization/gptq_bitblas.py @@ -6,7 +6,6 @@ from torch.nn.parameter import Parameter import bitblas.cache -from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.linear import ( LinearBase, @@ -404,7 +403,7 @@ def 
_get_or_create_bitblas_operator(self, config, enable_tuning): bitblas_matmul = global_operator_cache.get(config) if bitblas_matmul is None: - bitblas_matmul = Matmul(config, target=BITBLAS_TARGET) + bitblas_matmul = Matmul(config, target=BITBLAS_TARGET, enable_tuning=False) if enable_tuning: bitblas_matmul.hardware_aware_finetune(topk=20) global_operator_cache.add(config, bitblas_matmul) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 69a65ff023bc..9a411c4c2874 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -15,6 +15,7 @@ "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b "BloomForCausalLM": ("bloom", "BloomForCausalLM"), + "BitnetForCausalLM": ("bitnet", "BitnetForCausalLM"), "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), "CohereForCausalLM": ("commandr", "CohereForCausalLM"), diff --git a/vllm/model_executor/models/bitnet.py b/vllm/model_executor/models/bitnet.py new file mode 100644 index 000000000000..e35ff4230762 --- /dev/null +++ b/vllm/model_executor/models/bitnet.py @@ -0,0 +1,783 @@ +# coding=utf-8 +# Adapted from +# https://github.com/microsoft/BitBLAS/blob/main/integration/BitNet/modeling_bitnet.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Bitnet model compatible with HuggingFace weights.""" +from typing import Dict, Iterable, List, Optional, Tuple +import torch +from torch import nn +from transformers.configuration_utils import PretrainedConfig +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + kv_cache_scales_loader, +) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import SamplerOutput +from vllm.utils import is_hip, print_warning_once +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class BitnetConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`BitnetModel`]. It is used to instantiate an LLaMA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LLaMA-7B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`BitnetModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. 
Bitnet 1 supports up to 2048 tokens, + Bitnet 2 up to 4096, CodeBitnet up to 16384. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
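+
+    Example of the expected `rope_scaling` format (illustrative values only):
+
+    ```python
+    >>> configuration = BitnetConfig(rope_scaling={"type": "linear", "factor": 2.0})
+    ```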
+ + ```python + >>> from transformers import BitnetModel, BitnetConfig + + >>> # Initializing a LLaMA llama-7b style configuration + >>> configuration = BitnetConfig() + + >>> # Initializing a model from the llama-7b style configuration + >>> model = BitnetModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "llama" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + weight_bits=1, + input_bits=8, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.weight_bits = weight_bits + self.input_bits = input_bits + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. 
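+        Accepts only {"type": "linear" | "dynamic", "factor": float > 1.0};
+        any other value raises ValueError.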
+ """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if ( + rope_scaling_factor is None + or not isinstance(rope_scaling_factor, float) + or rope_scaling_factor <= 1.0 + ): + raise ValueError( + f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}" + ) + + +class BitnetMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + config: BitnetConfig = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." + ) + self.act_fn = SiluAndMul() + self.ffn_layernorm = BitnetRMSNorm(intermediate_size, eps=config.rms_norm_eps) + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x = self.ffn_layernorm(x) + x, _ = self.down_proj(x) + return x + + +class BitnetRotaryEmbedding(nn.Module): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base + ** ( + torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) + / self.dim + ) + ) + self.register_buffer("inv_freq", inv_freq) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=torch.int64 + ).type_as(self.inv_freq) + t = t / self.scaling_factor + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + "_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False + ) + self.register_buffer( + "_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False + ) + + @property + def sin_cached(self): + logger.warning_once( + "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "the forward method of RoPE from now on instead. It is not used in the `BitnetAttention` class" + ) + return self._sin_cached + + @property + def cos_cached(self): + logger.warning_once( + "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "the forward method of RoPE from now on instead. 
It is not used in the `BitnetAttention` class" + ) + return self._cos_cached + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = ( + self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + ) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = ( + device_type + if isinstance(device_type, str) and device_type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): + freqs = ( + inv_freq_expanded.float() @ position_ids_expanded.float() + ).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class BitnetAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, float]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + config: BitnetConfig = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
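+            # e.g. with total_num_kv_heads=2 and tp_size=4, each KV head is
+            # shared by tp_size // total_num_kv_heads = 2 ranks, and
+            # num_kv_heads below is clamped to max(1, 2 // 4) = 1 per rank.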
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.num_kv_groups = self.num_heads // self.num_kv_heads + self.head_dim = hidden_size // self.total_num_heads + self.padded_head_dim = self.find_flash_attn_supported_head_dims(self.head_dim) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.attention_dropout = config.attention_dropout + + self.q_proj = RowParallelLinear( + input_size=hidden_size, + output_size=self.head_dim * self.num_heads, + bias=bias, + quant_config=quant_config, + ) + + self.k_proj = RowParallelLinear( + input_size=hidden_size, + output_size=self.head_dim * self.num_heads, + bias=bias, + quant_config=quant_config, + ) + + self.v_proj = RowParallelLinear( + input_size=hidden_size, + output_size=self.head_dim * self.num_heads, + bias=bias, + quant_config=quant_config, + ) + + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + ) + + self.attn = Attention( + self.num_heads, + self.padded_head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + + self.inner_attn_ln = BitnetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def find_flash_attn_supported_head_dims(self, head_dim: int) -> int: + """ + Find the closest head dimension to the given head dimension that is supported by Flash Attention. + """ + from vllm.attention.backends.flash_attn import FlashAttentionBackend + + FLASHATTN_SUPPORTED_HEAD_DIMS = FlashAttentionBackend.get_supported_head_sizes() + for supported_head_dim in FLASHATTN_SUPPORTED_HEAD_DIMS: + if head_dim <= supported_head_dim: + return supported_head_dim + raise ValueError( + f"Head dimension {head_dim} is not supported by Flash Attention. Supported head dimensions are " + f"{FLASHATTN_SUPPORTED_HEAD_DIMS}." 
+ ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + # QKV projection cannot be grouped as the they do not share the same scaling factor + q, _ = self.q_proj(hidden_states) + k, _ = self.k_proj(hidden_states) + v, _ = self.v_proj(hidden_states) + q, k = self.rotary_emb(positions, q, k) + # Padding as paged attention doesn't support head_dim == 100 + q = torch.nn.functional.pad( + q.view(-1, self.total_num_heads, self.head_dim), + (0, self.padded_head_dim - self.head_dim), + ).view(-1, self.total_num_heads * self.padded_head_dim) + k = torch.nn.functional.pad( + k.view(-1, self.num_kv_heads, self.head_dim), + (0, self.padded_head_dim - self.head_dim), + ).view(-1, self.total_num_kv_heads * self.padded_head_dim) + v = torch.nn.functional.pad( + v.view(-1, self.num_kv_heads, self.head_dim), + (0, self.padded_head_dim - self.head_dim), + ).view(-1, self.total_num_kv_heads * self.padded_head_dim) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = attn_output.view(-1, self.total_num_heads, self.padded_head_dim)[ + ..., : self.head_dim + ].reshape(-1, self.total_num_heads * self.head_dim) + attn_output = self.inner_attn_ln(attn_output) + output, _ = self.o_proj(attn_output) + return output + + +class BitnetRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + BitnetRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class BitnetDecoderLayer(nn.Module): + + def __init__( + self, + config: BitnetConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None + ): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False + ) + self.self_attn = BitnetAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + cache_config=cache_config, + config=config, + ) + self.mlp = BitnetMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + config=config, + ) + self.input_layernorm = BitnetRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = BitnetRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: 
torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = residual + hidden_states + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class BitnetModel(nn.Module): + + def __init__( + self, + config: BitnetConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.org_vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + self.layers = nn.ModuleList( + [ + BitnetDecoderLayer( + config=config, cache_config=cache_config, quant_config=quant_config + ) + for _ in range(config.num_hidden_layers) + ] + ) + self.norm = BitnetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states = layer( + positions, + hidden_states, + kv_caches[i], + attn_metadata, + ) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class BitnetForCausalLM(nn.Module): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config: BitnetConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + + self.config = config + + self.model = BitnetModel(config, cache_config, quant_config) + self.unpadded_vocab_size = config.vocab_size + + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, logit_scale + ) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) + return hidden_states + + def compute_logits( + self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata + ) -> torch.Tensor: + logits = 
self.logits_processor( + self.lm_head.weight, hidden_states, sampling_metadata + ) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace(".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + print_warning_once( + f"Found kv scale in the checkpoint (e.g. {name}), " + "but not found the expected name in the model " + f"(e.g. {remapped_kv_scale_name}). kv-scale is " + "not loaded." + ) + continue + else: + name = remapped_kv_scale_name + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + # If this function is called, it should always initialize KV cache scale + # factors (or else raise an exception). Thus, handled exceptions should + # make sure to leave KV cache scale factors in a known good (dummy) state + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + for layer_idx, scaling_factor in kv_cache_scales_loader( + quantization_param_path, + tp_rank, + tp_size, + self.config.num_hidden_layers, + self.config.__class__.model_type, + ): + layer_self_attn = self.model.layers[layer_idx].self_attn + + if is_hip(): + # The scaling factor convention we are assuming is + # quantized_value * scaling_factor ~= true_value + # which is consistent with the practice of setting + # scaling_factor = tensor_amax / FPtype_max + scaling_factor *= 2 + if hasattr(layer_self_attn, "kv_scale"): + layer_self_attn.attn._kv_scale = scaling_factor + else: + raise RuntimeError( + "Self attention has no KV cache scaling " "factor attribute!" + ) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index f5684dbf1271..62613ec61e75 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -96,11 +96,21 @@ def get_tokenizer( revision=revision, **kwargs) except ValueError as e: - # If the error pertains to the tokenizer class not existing or not - # currently being imported, suggest using the --trust-remote-code flag. 
- if (not trust_remote_code and + if "BitnetTokenizer" in str(e): + # This is for the error "'BitnetTokenizer' object has no + # attribute 'sp_model'". + from vllm.transformers_utils.tokenizers.bitnet import BitnetTokenizer + tokenizer = BitnetTokenizer.from_pretrained( + tokenizer_name, + *args, + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs) + elif (not trust_remote_code and ("does not exist or is not currently imported." in str(e) or "requires you to execute the tokenizer file" in str(e))): + # If the error pertains to the tokenizer class not existing or not + # currently being imported, suggest using the --trust-remote-code flag. err_msg = ( "Failed to load the tokenizer. If the tokenizer is a custom " "tokenizer not yet available in the HuggingFace transformers " diff --git a/vllm/transformers_utils/tokenizers/__init__.py b/vllm/transformers_utils/tokenizers/__init__.py index e6b59722c259..090691dade27 100644 --- a/vllm/transformers_utils/tokenizers/__init__.py +++ b/vllm/transformers_utils/tokenizers/__init__.py @@ -1,5 +1,7 @@ from vllm.transformers_utils.tokenizers.baichuan import BaichuanTokenizer +from vllm.transformers_utils.tokenizers.bitnet import BitnetTokenizer __all__ = [ "BaichuanTokenizer", + "BitnetTokenizer" ] diff --git a/vllm/transformers_utils/tokenizers/bitnet.py b/vllm/transformers_utils/tokenizers/bitnet.py new file mode 100644 index 000000000000..8c25bf900747 --- /dev/null +++ b/vllm/transformers_utils/tokenizers/bitnet.py @@ -0,0 +1,504 @@ +# Adapted from +# https://github.com/microsoft/BitBLAS/blob/main/integration/BitNet/tokenization_bitnet.py + +"""Tokenization classes for Bitnet.""" +import os +from shutil import copyfile +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +import sentencepiece as spm + +from transformers.convert_slow_tokenizer import import_protobuf +from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer +from transformers.utils import logging + + +if TYPE_CHECKING: + from transformers.tokenization_utils_base import TextInput + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", + }, + "tokenizer_file": { + "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", + }, +} +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "hf-internal-testing/llama-tokenizer": 2048, +} +SPIECE_UNDERLINE = "▁" + +B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "<>\n", "\n<>\n\n" + +# fmt: off +DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \ +answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\ + that your responses are socially unbiased and positive in nature. + +If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \ +correct. If you don't know the answer to a question, please don't share false information.""" +# fmt: on + + +class BitnetTokenizer(PreTrainedTokenizer): + """ + Construct a Bitnet tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is + no padding token in the original model. 
+ + Args: + vocab_file (`str`): + Path to the vocabulary file. + unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The end of sequence token. + pad_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by + attention mechanisms or loss computation. + sp_model_kwargs (`Dict[str, Any]`, `Optional`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + + add_bos_token (`bool`, *optional*, defaults to `True`): + Whether or not to add an `bos_token` at the start of sequences. + add_eos_token (`bool`, *optional*, defaults to `False`): + Whether or not to add an `eos_token` at the end of sequences. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like + extra spaces. + use_default_system_prompt (`bool`, *optional*, defaults to `False`): + Whether or not the default system prompt for Bitnet should be used. + spaces_between_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to add spaces between special tokens. + legacy (`bool`, *optional*): + Whether or not the `legacy` behavior of the tokenizer should be used. Legacy is before the merge of #24622 + and #25224 which includes fixes to properly handle tokens that appear after special tokens. A simple + example: + + - `legacy=True`: + ```python + >>> from transformers import T5Tokenizer + + >>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=True) + >>> tokenizer.encode("Hello .") + [8774, 32099, 3, 5, 1] + ``` + - `legacy=False`: + ```python + >>> from transformers import T5Tokenizer + + >>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=False) + >>> tokenizer.encode("Hello .") # the extra space `[3]` is no longer here + [8774, 32099, 5, 1] + ``` + Checkout the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details. + add_prefix_space (`bool`, *optional*, defaults to `True`): + Whether or not to add an initial space to the input. This allows to treat the leading word just as any + other word. 
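+
+    Example (illustrative; assumes a local SentencePiece model file):
+
+    ```python
+    >>> from vllm.transformers_utils.tokenizers.bitnet import BitnetTokenizer
+
+    >>> tokenizer = BitnetTokenizer(vocab_file="tokenizer.model")
+    >>> tokenizer.encode("Hello")
+    ```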
+ + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + unk_token="", + bos_token="", + eos_token="", + pad_token=None, + sp_model_kwargs: Optional[Dict[str, Any]] = None, + add_bos_token=True, + add_eos_token=False, + clean_up_tokenization_spaces=False, + use_default_system_prompt=False, + spaces_between_special_tokens=False, + legacy=None, + add_prefix_space=True, + **kwargs, + ): + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + bos_token = ( + AddedToken(bos_token, normalized=False, special=True) + if isinstance(bos_token, str) + else bos_token + ) + eos_token = ( + AddedToken(eos_token, normalized=False, special=True) + if isinstance(eos_token, str) + else eos_token + ) + unk_token = ( + AddedToken(unk_token, normalized=False, special=True) + if isinstance(unk_token, str) + else unk_token + ) + pad_token = ( + AddedToken(pad_token, normalized=False, special=True) + if isinstance(pad_token, str) + else pad_token + ) + + if legacy is None: + logger.warning_once( + f"You are using the default legacy behaviour of the {self.__class__}. This is" + " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you." + " If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it" + " means, and thoroughly read the reason why this was added as explained in" + " https://github.com/huggingface/transformers/pull/24565" + ) + legacy = True + + self.legacy = legacy + self.vocab_file = vocab_file + self.add_bos_token = add_bos_token + self.add_eos_token = add_eos_token + self.use_default_system_prompt = use_default_system_prompt + self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False)) + self.add_prefix_space = add_prefix_space + + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + add_bos_token=add_bos_token, + add_eos_token=add_eos_token, + sp_model_kwargs=self.sp_model_kwargs, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + use_default_system_prompt=use_default_system_prompt, + spaces_between_special_tokens=spaces_between_special_tokens, + legacy=legacy, + add_prefix_space=add_prefix_space, + **kwargs, + ) + + @property + def unk_token_length(self): + return len(self.sp_model.encode(str(self.unk_token))) + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor + def get_spm_processor(self, from_slow=False): + tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs) + if self.legacy or from_slow: # no dependency on protobuf + tokenizer.Load(self.vocab_file) + return tokenizer + + with open(self.vocab_file, "rb") as f: + sp_model = f.read() + model_pb2 = import_protobuf( + f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)" + ) + model = model_pb2.ModelProto.FromString(sp_model) + normalizer_spec = model_pb2.NormalizerSpec() + normalizer_spec.add_dummy_prefix = False + model.normalizer_spec.MergeFrom(normalizer_spec) + sp_model = model.SerializeToString() + tokenizer.LoadFromSerializedProto(sp_model) + return tokenizer + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + state["sp_model_proto"] = self.sp_model.serialized_model_proto() + return state + + def __setstate__(self, d): + 
self.__dict__ = d + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.LoadFromSerializedProto(self.sp_model_proto) + + @property + def vocab_size(self): + """Returns vocab size""" + return self.sp_model.get_piece_size() + + def get_vocab(self): + """Returns vocab as a dict""" + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize + def tokenize(self, text: "TextInput", **kwargs) -> List[str]: + """ + Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the + first token is special. + """ + if self.legacy or len(text) == 0: + return super().tokenize(text, **kwargs) + + text = text.replace(SPIECE_UNDERLINE, " ") + if self.add_prefix_space: + text = SPIECE_UNDERLINE + text + + tokens = super().tokenize(text, **kwargs) + + if ( + len(tokens) > 1 + and tokens[0] == SPIECE_UNDERLINE + and tokens[1] in self.all_special_tokens + ): + tokens = tokens[1:] + return tokens + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize + def _tokenize(self, text, **kwargs): + """ + Returns a tokenized string. + + We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any + SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give + `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the + `unk_token`. Here is an example with `unk_token = ""` and `unk_token_length = 4`. + `self.tokenizer.sp_model.encode(" Hey", out_type = str)[4:]`. + """ + tokens = self.sp_model.encode(text, out_type=str) + if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")): + return tokens + + # 1. Encode string + prefix ex: " Hey" + tokens = self.sp_model.encode(self.unk_token + text, out_type=str) + # 2. Remove self.unk_token from ['<','unk','>', '▁Hey'] + return ( + tokens[self.unk_token_length :] + if len(tokens) >= self.unk_token_length + else tokens + ) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = self.sp_model.IdToPiece(index) + return token + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + # since we manually add the prefix space, we have to remove it when decoding + if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space: + tokens[0] = tokens[0][1:] + + current_sub_tokens = [] + out_string = "" + prev_is_special = False + for i, token in enumerate(tokens): + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + if not prev_is_special and i != 0 and self.legacy: + out_string += " " + out_string += self.sp_model.decode(current_sub_tokens) + token + prev_is_special = True + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + prev_is_special = False + out_string += self.sp_model.decode(current_sub_tokens) + return out_string + + def save_vocabulary( + self, save_directory, filename_prefix: Optional[str] = None + ) -> Tuple[str]: + """ + Save the vocabulary and special tokens file to a directory. 
+ + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + + Returns: + `Tuple(str)`: Paths to the files saved. + """ + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, + (filename_prefix + "-" if filename_prefix else "") + + VOCAB_FILES_NAMES["vocab_file"], + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath( + out_vocab_file + ) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = bos_token_id + token_ids_0 + eos_token_id + + if token_ids_1 is not None: + output = output + bos_token_id + token_ids_1 + eos_token_id + + return output + + def get_special_tokens_mask( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None, + already_has_special_tokens: bool = False, + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, + token_ids_1=token_ids_1, + already_has_special_tokens=True, + ) + + bos_token_id = [1] if self.add_bos_token else [] + eos_token_id = [1] if self.add_eos_token else [] + + if token_ids_1 is None: + return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + return ( + bos_token_id + + ([0] * len(token_ids_0)) + + eos_token_id + + bos_token_id + + ([0] * len(token_ids_1)) + + eos_token_id + ) + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT + sequence pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + if token_ids_1 is None, only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of ids. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). 
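+
+        Worked example (with add_bos_token=True and add_eos_token=False):
+        token_ids_0=[5, 6] and token_ids_1=[7] give [0, 0, 0, 1, 1], where
+        the extra 0 and 1 account for the BOS token added to each sequence.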
+ """ + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = [0] * len(bos_token_id + token_ids_0 + eos_token_id) + + if token_ids_1 is not None: + output += [1] * len(bos_token_id + token_ids_1 + eos_token_id) + + return output + + @property + def default_chat_template(self): + """ + LLaMA uses [INST] and [/INST] to indicate user messages, and <> and <> to indicate system messages. + Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict + user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering + rather than needing special tokens. The system message is partly 'embedded' in the first user message, which + results in an unusual token ordering when it is present. This template should definitely be changed if you wish + to fine-tune a model with more flexible role ordering! + + The output should look something like: + + [INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer [INST] Prompt [/INST] Answer + [INST] Prompt [/INST] + + The reference for this chat template is [this code + snippet](https://github.com/facebookresearch/llama/blob/556949fdfb72da27c2f4a40b7f0e4cf0b8153a28/llama/generation.py#L320-L362) + in the original repository. + """ + logger.warning_once( + "\nNo chat template is defined for this tokenizer - using the default template " + f"for the {self.__class__.__name__} class. If the default is not appropriate for " + "your model, please set `tokenizer.chat_template` to an appropriate template. " + "See https://huggingface.co/docs/transformers/main/chat_templating for more information.\n" + ) + template = ( + "{% if messages[0]['role'] == 'system' %}" + "{% set loop_messages = messages[1:] %}" # Extract system message if it's present + "{% set system_message = messages[0]['content'] %}" + "{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}" + "{% set loop_messages = messages %}" # Or use the default system message if the flag is set + "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}" + "{% else %}" + "{% set loop_messages = messages %}" + "{% set system_message = false %}" + "{% endif %}" + "{% for message in loop_messages %}" # Loop over all non-system messages + "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" + "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" + "{% endif %}" + "{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message + "{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}" + "{% else %}" + "{% set content = message['content'] %}" + "{% endif %}" + "{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way + "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}" + "{% elif message['role'] == 'system' %}" + "{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}" + "{% elif message['role'] == 'assistant' %}" + "{{ ' ' + content.strip() + ' ' + eos_token }}" + "{% endif %}" + "{% endfor %}" + ) + template = template.replace( + "USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false" + ) + default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'") + template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message) + + return template From dfa6b2f96ed8bf48c14c2501bea4107475f4bbc8 Mon Sep 17 00:00:00 2001 From: LeiWang1999 
Date: Tue, 16 Jul 2024 03:20:18 +0000 Subject: [PATCH 04/24] Lint Fix --- vllm/config.py | 16 +- .../layers/quantization/bitnet_bitblas.py | 49 +++-- .../layers/quantization/gptq_bitblas.py | 4 +- vllm/model_executor/models/bitnet.py | 178 ++++++++---------- vllm/transformers_utils/tokenizer.py | 6 +- .../transformers_utils/tokenizers/__init__.py | 5 +- vllm/transformers_utils/tokenizers/bitnet.py | 101 ++++------ 7 files changed, 169 insertions(+), 190 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index dce48eeda56b..c842e085f433 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -302,15 +302,15 @@ def find_flash_attn_supported_head_dims(self, head_dim: int) -> int: Find the closest head dimension to the given head dimension that is supported by Flash Attention. """ from vllm.attention.backends.flash_attn import FlashAttentionBackend - FLASHATTN_SUPPORTED_HEAD_DIMS = FlashAttentionBackend.get_supported_head_sizes() + FLASHATTN_SUPPORTED_HEAD_DIMS = FlashAttentionBackend.get_supported_head_sizes( + ) for supported_head_dim in FLASHATTN_SUPPORTED_HEAD_DIMS: if head_dim <= supported_head_dim: return supported_head_dim raise ValueError( f"Head dimension {head_dim} is not supported by Flash Attention. Supported head dimensions are " - f"{FLASHATTN_SUPPORTED_HEAD_DIMS}." - ) + f"{FLASHATTN_SUPPORTED_HEAD_DIMS}.") def get_head_size(self) -> int: # TODO remove hard code @@ -319,10 +319,12 @@ def get_head_size(self) -> int: # FlashAttention supports only head_size 32, 64, 128, 256, # we need to pad head_size 192 to 256 return 256 - if hasattr(self.hf_text_config, "architectures" - ) and 'BitnetForCausalLM' in self.hf_text_config.architectures: - return self.find_flash_attn_supported_head_dims((self.hf_text_config.hidden_size // - self.hf_text_config.num_attention_heads)) + if hasattr( + self.hf_text_config, "architectures" + ) and 'BitnetForCausalLM' in self.hf_text_config.architectures: + return self.find_flash_attn_supported_head_dims( + (self.hf_text_config.hidden_size // + self.hf_text_config.num_attention_heads)) if hasattr(self.hf_text_config, "head_dim"): return self.hf_text_config.head_dim diff --git a/vllm/model_executor/layers/quantization/bitnet_bitblas.py b/vllm/model_executor/layers/quantization/bitnet_bitblas.py index 0b3f7ca14526..03d7b3ea862d 100644 --- a/vllm/model_executor/layers/quantization/bitnet_bitblas.py +++ b/vllm/model_executor/layers/quantization/bitnet_bitblas.py @@ -60,7 +60,9 @@ def __init__(self, weight_bits: int, is_sym: bool) -> None: self.nbits = weight_bits def __repr__(self) -> str: - return (f"BITNETBitBLASConfig(weight_bits={self.weight_bits}, is_sym={self.is_sym})") + return ( + f"BITNETBitBLASConfig(weight_bits={self.weight_bits}, is_sym={self.is_sym})" + ) @classmethod def get_name(cls) -> str: @@ -98,10 +100,11 @@ def override_quantization_method(cls, hf_quant_cfg, return cls.get_name() if can_convert and user_quant == "bitnet": - logger.info("Detected that the model can run with bitnet_bitblas" - ", however you specified quantization=bitnet explicitly," - " so forcing bitnet. Use quantization=bitnet_bitblas for" - " faster inference") + logger.info( + "Detected that the model can run with bitnet_bitblas" + ", however you specified quantization=bitnet explicitly," + " so forcing bitnet. 
Use quantization=bitnet_bitblas for" + " faster inference") return None def get_quant_method( @@ -289,7 +292,9 @@ def _get_or_create_bitblas_operator(self, config, enable_tuning): bitblas_matmul = global_operator_cache.get(config) if bitblas_matmul is None: - bitblas_matmul = Matmul(config, target=BITBLAS_TARGET, enable_tuning=False) + bitblas_matmul = Matmul(config, + target=BITBLAS_TARGET, + enable_tuning=False) if enable_tuning: bitblas_matmul.hardware_aware_finetune(topk=20) global_operator_cache.add(config, bitblas_matmul) @@ -319,18 +324,26 @@ def activation_quant(self, x, num_bits=8): result = (x * s).round().clamp(Qn, Qp) return result.type(torch.int8) - def repack_bitblas_from_bitnet(self, b_q_weight: torch.Tensor, is_qkv_packed: bool=False): + def repack_bitblas_from_bitnet(self, + b_q_weight: torch.Tensor, + is_qkv_packed: bool = False): if is_qkv_packed: hidden_size = b_q_weight.size(0) - sw_q = 1 / b_q_weight[:hidden_size // 3].abs().mean().clamp(min=1e-5) - sw_k = 1 / b_q_weight[hidden_size // 3:2 * hidden_size // 3].abs().mean().clamp(min=1e-5) - sw_v = 1 / b_q_weight[2 * hidden_size // 3:].abs().mean().clamp(min=1e-5) + sw_q = 1 / b_q_weight[:hidden_size // + 3].abs().mean().clamp(min=1e-5) + sw_k = 1 / b_q_weight[hidden_size // 3:2 * hidden_size // + 3].abs().mean().clamp(min=1e-5) + sw_v = 1 / b_q_weight[2 * hidden_size // + 3:].abs().mean().clamp(min=1e-5) self.sw_q = sw_q self.sw_k = sw_k self.sw_v = sw_v - qweight_q = self.weight_quant(b_q_weight[:hidden_size // 3]).detach() - qweight_k = self.weight_quant(b_q_weight[hidden_size // 3:2 * hidden_size // 3]).detach() - qweight_v = self.weight_quant(b_q_weight[2 * hidden_size // 3:]).detach() + qweight_q = self.weight_quant(b_q_weight[:hidden_size // + 3]).detach() + qweight_k = self.weight_quant( + b_q_weight[hidden_size // 3:2 * hidden_size // 3]).detach() + qweight_v = self.weight_quant(b_q_weight[2 * hidden_size // + 3:]).detach() qweight = torch.cat([qweight_q, qweight_k, qweight_v], dim=0) else: sw = 1 / b_q_weight.abs().mean().clamp(min=1e-5) @@ -348,8 +361,9 @@ def apply( part_size_n = layer.output_size_per_partition out_shape = x.shape[:-1] + (part_size_n, ) - quant_input = self.activation_quant(x, self.quant_config.input_bits).detach() - + quant_input = self.activation_quant( + x, self.quant_config.input_bits).detach() + if layer.bitblas_state == BITNETBitBLASState.REPACK: layer.bitblas_state = BITNETBitBLASState.READY @@ -358,6 +372,7 @@ def apply( def free_tensor(name): # free the original weight tensor delattr(layer, name) + def replace_tensor(name, new_t): # Cannot use copy_() because the storage shape and dtype are different # del layer._parameters[name] @@ -365,9 +380,7 @@ def replace_tensor(name, new_t): setattr(layer, name, new_t) # Repack weights - bitblas_qweight = self.repack_bitblas_from_bitnet( - layer.weight, - ) + bitblas_qweight = self.repack_bitblas_from_bitnet(layer.weight, ) # free the original weight tensor free_tensor("weight") replace_tensor("qweight", bitblas_qweight) diff --git a/vllm/model_executor/layers/quantization/gptq_bitblas.py b/vllm/model_executor/layers/quantization/gptq_bitblas.py index 75f9b087f07a..14df921e710a 100644 --- a/vllm/model_executor/layers/quantization/gptq_bitblas.py +++ b/vllm/model_executor/layers/quantization/gptq_bitblas.py @@ -403,7 +403,9 @@ def _get_or_create_bitblas_operator(self, config, enable_tuning): bitblas_matmul = global_operator_cache.get(config) if bitblas_matmul is None: - bitblas_matmul = Matmul(config, target=BITBLAS_TARGET, 
enable_tuning=False) + bitblas_matmul = Matmul(config, + target=BITBLAS_TARGET, + enable_tuning=False) if enable_tuning: bitblas_matmul.hardware_aware_finetune(topk=20) global_operator_cache.add(config, bitblas_matmul) diff --git a/vllm/model_executor/models/bitnet.py b/vllm/model_executor/models/bitnet.py index e35ff4230762..dbb4ddb2e3ac 100644 --- a/vllm/model_executor/models/bitnet.py +++ b/vllm/model_executor/models/bitnet.py @@ -209,22 +209,22 @@ def _rope_scaling_validation(self): if self.rope_scaling is None: return - if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + if not isinstance(self.rope_scaling, + dict) or len(self.rope_scaling) != 2: raise ValueError( "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " - f"got {self.rope_scaling}" - ) + f"got {self.rope_scaling}") rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + if rope_scaling_type is None or rope_scaling_type not in [ + "linear", "dynamic" + ]: raise ValueError( f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" ) - if ( - rope_scaling_factor is None - or not isinstance(rope_scaling_factor, float) - or rope_scaling_factor <= 1.0 - ): + if (rope_scaling_factor is None + or not isinstance(rope_scaling_factor, float) + or rope_scaling_factor <= 1.0): raise ValueError( f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}" ) @@ -255,12 +255,11 @@ def __init__( quant_config=quant_config, ) if hidden_act != "silu": - raise ValueError( - f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now." - ) + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") self.act_fn = SiluAndMul() - self.ffn_layernorm = BitnetRMSNorm(intermediate_size, eps=config.rms_norm_eps) + self.ffn_layernorm = BitnetRMSNorm(intermediate_size, + eps=config.rms_norm_eps) def forward(self, x): gate_up, _ = self.gate_up_proj(x) @@ -271,6 +270,7 @@ def forward(self, x): class BitnetRotaryEmbedding(nn.Module): + def __init__( self, dim, @@ -284,29 +284,24 @@ def __init__( self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / ( - self.base - ** ( - torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) - / self.dim - ) - ) + inv_freq = 1.0 / (self.base**(torch.arange( + 0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq) # For BC we register cos and sin cached self.max_seq_len_cached = max_position_embeddings - t = torch.arange( - self.max_seq_len_cached, device=device, dtype=torch.int64 - ).type_as(self.inv_freq) + t = torch.arange(self.max_seq_len_cached, + device=device, + dtype=torch.int64).type_as(self.inv_freq) t = t / self.scaling_factor freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer( - "_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False - ) - self.register_buffer( - "_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False - ) + self.register_buffer("_cos_cached", + emb.cos().to(torch.get_default_dtype()), + persistent=False) + self.register_buffer("_sin_cached", + emb.sin().to(torch.get_default_dtype()), + persistent=False) @property def sin_cached(self): @@ -327,22 +322,17 @@ def cos_cached(self): @torch.no_grad() def forward(self, x, position_ids): # x: [bs, num_attention_heads, seq_len, head_size] - inv_freq_expanded = ( - self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - ) + inv_freq_expanded = (self.inv_freq[None, :, None].float().expand( + position_ids.shape[0], -1, 1)) position_ids_expanded = position_ids[:, None, :].float() # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 device_type = x.device.type - device_type = ( - device_type - if isinstance(device_type, str) and device_type != "mps" - else "cpu" - ) + device_type = (device_type if isinstance(device_type, str) + and device_type != "mps" else "cpu") with torch.autocast(device_type=device_type, enabled=False): - freqs = ( - inv_freq_expanded.float() @ position_ids_expanded.float() - ).transpose(1, 2) + freqs = (inv_freq_expanded.float() + @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() @@ -382,7 +372,8 @@ def __init__( self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) self.num_kv_groups = self.num_heads // self.num_kv_heads self.head_dim = hidden_size // self.total_num_heads - self.padded_head_dim = self.find_flash_attn_supported_head_dims(self.head_dim) + self.padded_head_dim = self.find_flash_attn_supported_head_dims( + self.head_dim) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -435,7 +426,8 @@ def __init__( rope_scaling=rope_scaling, ) - self.inner_attn_ln = BitnetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.inner_attn_ln = BitnetRMSNorm(config.hidden_size, + 
eps=config.rms_norm_eps) def find_flash_attn_supported_head_dims(self, head_dim: int) -> int: """ @@ -443,14 +435,14 @@ def find_flash_attn_supported_head_dims(self, head_dim: int) -> int: """ from vllm.attention.backends.flash_attn import FlashAttentionBackend - FLASHATTN_SUPPORTED_HEAD_DIMS = FlashAttentionBackend.get_supported_head_sizes() + FLASHATTN_SUPPORTED_HEAD_DIMS = FlashAttentionBackend.get_supported_head_sizes( + ) for supported_head_dim in FLASHATTN_SUPPORTED_HEAD_DIMS: if head_dim <= supported_head_dim: return supported_head_dim raise ValueError( f"Head dimension {head_dim} is not supported by Flash Attention. Supported head dimensions are " - f"{FLASHATTN_SUPPORTED_HEAD_DIMS}." - ) + f"{FLASHATTN_SUPPORTED_HEAD_DIMS}.") def forward( self, @@ -478,15 +470,17 @@ def forward( (0, self.padded_head_dim - self.head_dim), ).view(-1, self.total_num_kv_heads * self.padded_head_dim) attn_output = self.attn(q, k, v, kv_cache, attn_metadata) - attn_output = attn_output.view(-1, self.total_num_heads, self.padded_head_dim)[ - ..., : self.head_dim - ].reshape(-1, self.total_num_heads * self.head_dim) + attn_output = attn_output.view( + -1, self.total_num_heads, + self.padded_head_dim)[..., :self.head_dim].reshape( + -1, self.total_num_heads * self.head_dim) attn_output = self.inner_attn_ln(attn_output) output, _ = self.o_proj(attn_output) return output class BitnetRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): """ BitnetRMSNorm is equivalent to T5LayerNorm @@ -499,7 +493,8 @@ def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = hidden_states * torch.rsqrt(variance + + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) @@ -516,22 +511,19 @@ def __init__( rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None - ): + config, "original_max_position_embeddings", None): rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings - ) - max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) attention_bias = getattr(config, "attention_bias", False) or getattr( - config, "bias", False - ) + config, "bias", False) self.self_attn = BitnetAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, - num_kv_heads=getattr( - config, "num_key_value_heads", config.num_attention_heads - ), + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -548,12 +540,10 @@ def __init__( bias=getattr(config, "mlp_bias", False), config=config, ) - self.input_layernorm = BitnetRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - self.post_attention_layernorm = BitnetRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) + self.input_layernorm = BitnetRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = BitnetRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) def forward( self, @@ -598,14 +588,12 @@ def __init__( config.hidden_size, 
org_num_embeddings=config.vocab_size, ) - self.layers = nn.ModuleList( - [ - BitnetDecoderLayer( - config=config, cache_config=cache_config, quant_config=quant_config - ) - for _ in range(config.num_hidden_layers) - ] - ) + self.layers = nn.ModuleList([ + BitnetDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config) + for _ in range(config.num_hidden_layers) + ]) self.norm = BitnetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: @@ -677,9 +665,8 @@ def __init__( self.lm_head.weight = self.model.embed_tokens.weight logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size, logit_scale - ) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, logit_scale) self.sampler = Sampler() def forward( @@ -689,15 +676,14 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata) return hidden_states - def compute_logits( - self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata - ) -> torch.Tensor: - logits = self.logits_processor( - self.lm_head.weight, hidden_states, sampling_metadata - ) + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) return logits def sample( @@ -739,19 +725,20 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue # Remapping the name of FP8 kv-scale. if name.endswith("kv_scale"): - remapped_kv_scale_name = name.replace(".kv_scale", ".attn.kv_scale") + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") if remapped_kv_scale_name not in params_dict: print_warning_once( f"Found kv scale in the checkpoint (e.g. {name}), " "but not found the expected name in the model " f"(e.g. {remapped_kv_scale_name}). kv-scale is " - "not loaded." - ) + "not loaded.") continue else: name = remapped_kv_scale_name param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) weight_loader(param, loaded_weight) # If this function is called, it should always initialize KV cache scale @@ -761,11 +748,11 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None: tp_size = get_tensor_model_parallel_world_size() tp_rank = get_tensor_model_parallel_rank() for layer_idx, scaling_factor in kv_cache_scales_loader( - quantization_param_path, - tp_rank, - tp_size, - self.config.num_hidden_layers, - self.config.__class__.model_type, + quantization_param_path, + tp_rank, + tp_size, + self.config.num_hidden_layers, + self.config.__class__.model_type, ): layer_self_attn = self.model.layers[layer_idx].self_attn @@ -778,6 +765,5 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None: if hasattr(layer_self_attn, "kv_scale"): layer_self_attn.attn._kv_scale = scaling_factor else: - raise RuntimeError( - "Self attention has no KV cache scaling " "factor attribute!" 
- ) + raise RuntimeError("Self attention has no KV cache scaling " + "factor attribute!") diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 62613ec61e75..d8ca38e6191b 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -106,9 +106,9 @@ def get_tokenizer( trust_remote_code=trust_remote_code, revision=revision, **kwargs) - elif (not trust_remote_code and - ("does not exist or is not currently imported." in str(e) - or "requires you to execute the tokenizer file" in str(e))): + elif (not trust_remote_code + and ("does not exist or is not currently imported." in str(e) + or "requires you to execute the tokenizer file" in str(e))): # If the error pertains to the tokenizer class not existing or not # currently being imported, suggest using the --trust-remote-code flag. err_msg = ( diff --git a/vllm/transformers_utils/tokenizers/__init__.py b/vllm/transformers_utils/tokenizers/__init__.py index 090691dade27..1c8cd7660ea6 100644 --- a/vllm/transformers_utils/tokenizers/__init__.py +++ b/vllm/transformers_utils/tokenizers/__init__.py @@ -1,7 +1,4 @@ from vllm.transformers_utils.tokenizers.baichuan import BaichuanTokenizer from vllm.transformers_utils.tokenizers.bitnet import BitnetTokenizer -__all__ = [ - "BaichuanTokenizer", - "BitnetTokenizer" -] +__all__ = ["BaichuanTokenizer", "BitnetTokenizer"] diff --git a/vllm/transformers_utils/tokenizers/bitnet.py b/vllm/transformers_utils/tokenizers/bitnet.py index 8c25bf900747..203d56632d57 100644 --- a/vllm/transformers_utils/tokenizers/bitnet.py +++ b/vllm/transformers_utils/tokenizers/bitnet.py @@ -1,6 +1,5 @@ # Adapted from # https://github.com/microsoft/BitBLAS/blob/main/integration/BitNet/tokenization_bitnet.py - """Tokenization classes for Bitnet.""" import os from shutil import copyfile @@ -12,7 +11,6 @@ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer from transformers.utils import logging - if TYPE_CHECKING: from transformers.tokenization_utils_base import TextInput @@ -22,10 +20,12 @@ PRETRAINED_VOCAB_FILES_MAP = { "vocab_file": { - "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", + "hf-internal-testing/llama-tokenizer": + "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", }, "tokenizer_file": { - "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", + "hf-internal-testing/llama-tokenizer": + "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", }, } PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { @@ -142,26 +142,14 @@ def __init__( **kwargs, ): self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs - bos_token = ( - AddedToken(bos_token, normalized=False, special=True) - if isinstance(bos_token, str) - else bos_token - ) - eos_token = ( - AddedToken(eos_token, normalized=False, special=True) - if isinstance(eos_token, str) - else eos_token - ) - unk_token = ( - AddedToken(unk_token, normalized=False, special=True) - if isinstance(unk_token, str) - else unk_token - ) - pad_token = ( - AddedToken(pad_token, normalized=False, special=True) - if isinstance(pad_token, str) - else pad_token - ) + bos_token = (AddedToken(bos_token, normalized=False, special=True) + if isinstance(bos_token, str) else bos_token) + eos_token = (AddedToken(eos_token, normalized=False, special=True) + 
if isinstance(eos_token, str) else eos_token) + unk_token = (AddedToken(unk_token, normalized=False, special=True) + if isinstance(unk_token, str) else unk_token) + pad_token = (AddedToken(pad_token, normalized=False, special=True) + if isinstance(pad_token, str) else pad_token) if legacy is None: logger.warning_once( @@ -169,8 +157,7 @@ def __init__( " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you." " If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it" " means, and thoroughly read the reason why this was added as explained in" - " https://github.com/huggingface/transformers/pull/24565" - ) + " https://github.com/huggingface/transformers/pull/24565") legacy = True self.legacy = legacy @@ -239,7 +226,10 @@ def vocab_size(self): def get_vocab(self): """Returns vocab as a dict""" - vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab = { + self.convert_ids_to_tokens(i): i + for i in range(self.vocab_size) + } vocab.update(self.added_tokens_encoder) return vocab @@ -258,11 +248,8 @@ def tokenize(self, text: "TextInput", **kwargs) -> List[str]: tokens = super().tokenize(text, **kwargs) - if ( - len(tokens) > 1 - and tokens[0] == SPIECE_UNDERLINE - and tokens[1] in self.all_special_tokens - ): + if (len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE + and tokens[1] in self.all_special_tokens): tokens = tokens[1:] return tokens @@ -284,11 +271,8 @@ def _tokenize(self, text, **kwargs): # 1. Encode string + prefix ex: " Hey" tokens = self.sp_model.encode(self.unk_token + text, out_type=str) # 2. Remove self.unk_token from ['<','unk','>', '▁Hey'] - return ( - tokens[self.unk_token_length :] - if len(tokens) >= self.unk_token_length - else tokens - ) + return (tokens[self.unk_token_length:] + if len(tokens) >= self.unk_token_length else tokens) def _convert_token_to_id(self, token): """Converts a token (str) in an id using the vocab.""" @@ -322,9 +306,9 @@ def convert_tokens_to_string(self, tokens): out_string += self.sp_model.decode(current_sub_tokens) return out_string - def save_vocabulary( - self, save_directory, filename_prefix: Optional[str] = None - ) -> Tuple[str]: + def save_vocabulary(self, + save_directory, + filename_prefix: Optional[str] = None) -> Tuple[str]: """ Save the vocabulary and special tokens file to a directory. @@ -336,24 +320,24 @@ def save_vocabulary( `Tuple(str)`: Paths to the files saved. 
""" if not os.path.isdir(save_directory): - logger.error(f"Vocabulary path ({save_directory}) should be a directory") + logger.error( + f"Vocabulary path ({save_directory}) should be a directory") return out_vocab_file = os.path.join( save_directory, - (filename_prefix + "-" if filename_prefix else "") - + VOCAB_FILES_NAMES["vocab_file"], + (filename_prefix + "-" if filename_prefix else "") + + VOCAB_FILES_NAMES["vocab_file"], ) if os.path.abspath(self.vocab_file) != os.path.abspath( - out_vocab_file - ) and os.path.isfile(self.vocab_file): + out_vocab_file) and os.path.isfile(self.vocab_file): copyfile(self.vocab_file, out_vocab_file) elif not os.path.isfile(self.vocab_file): with open(out_vocab_file, "wb") as fi: content_spiece_model = self.sp_model.serialized_model_proto() fi.write(content_spiece_model) - return (out_vocab_file,) + return (out_vocab_file, ) def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): bos_token_id = [self.bos_token_id] if self.add_bos_token else [] @@ -399,18 +383,13 @@ def get_special_tokens_mask( if token_ids_1 is None: return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id - return ( - bos_token_id - + ([0] * len(token_ids_0)) - + eos_token_id - + bos_token_id - + ([0] * len(token_ids_1)) - + eos_token_id - ) + return (bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + + bos_token_id + ([0] * len(token_ids_1)) + eos_token_id) def create_token_type_ids_from_sequences( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None - ) -> List[int]: + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None) -> List[int]: """ Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT sequence pair mask has the following format: @@ -493,12 +472,12 @@ def default_chat_template(self): "{% elif message['role'] == 'assistant' %}" "{{ ' ' + content.strip() + ' ' + eos_token }}" "{% endif %}" - "{% endfor %}" - ) + "{% endfor %}") template = template.replace( - "USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false" - ) - default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'") + "USE_DEFAULT_PROMPT", + "true" if self.use_default_system_prompt else "false") + default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace( + "'", "\\'") template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message) return template From 8d2c6359bbf70ff13cadf4efbceda8d5f9250e42 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 16 Jul 2024 03:30:48 +0000 Subject: [PATCH 05/24] lint fix --- vllm/transformers_utils/tokenizers/bitnet.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/transformers_utils/tokenizers/bitnet.py b/vllm/transformers_utils/tokenizers/bitnet.py index 203d56632d57..32a9f22c457f 100644 --- a/vllm/transformers_utils/tokenizers/bitnet.py +++ b/vllm/transformers_utils/tokenizers/bitnet.py @@ -289,7 +289,7 @@ def convert_tokens_to_string(self, tokens): if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space: tokens[0] = tokens[0][1:] - current_sub_tokens = [] + current_sub_tokens: List[str] = [] out_string = "" prev_is_special = False for i, token in enumerate(tokens): @@ -306,9 +306,10 @@ def convert_tokens_to_string(self, tokens): out_string += self.sp_model.decode(current_sub_tokens) return out_string - def save_vocabulary(self, - save_directory, - filename_prefix: Optional[str] = None) -> Tuple[str]: + def save_vocabulary( + self, + save_directory, + 
filename_prefix: Optional[str] = None) -> Optional[Tuple[str]]: """ Save the vocabulary and special tokens file to a directory. @@ -322,7 +323,7 @@ def save_vocabulary(self, if not os.path.isdir(save_directory): logger.error( f"Vocabulary path ({save_directory}) should be a directory") - return + return None out_vocab_file = os.path.join( save_directory, (filename_prefix + "-" if filename_prefix else "") + From 41bb18e15044d89256ad70db8ff7417e03a3e343 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 16 Jul 2024 05:07:07 +0000 Subject: [PATCH 06/24] Lint Fix for line length --- vllm/config.py | 212 ++++++++++-------- vllm/model_executor/layers/linear.py | 7 +- .../layers/quantization/__init__.py | 9 +- .../layers/quantization/bitblas.py | 52 +++-- .../layers/quantization/bitnet_bitblas.py | 68 +++--- .../layers/quantization/gptq_bitblas.py | 75 ++++--- vllm/model_executor/models/bitnet.py | 41 ++-- vllm/transformers_utils/tokenizer.py | 8 +- vllm/transformers_utils/tokenizers/bitnet.py | 7 +- 9 files changed, 266 insertions(+), 213 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index c842e085f433..863ecf6b0d47 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -33,8 +33,8 @@ class ModelConfig: Args: model: Name or path of the huggingface model to use. - It is also used as the content for `model_name` tag in metrics - output when `served_model_name` is not specified. + It is also used as the content for `model_name` tag in metrics + output when `served_model_name` is not specified. tokenizer: Name or path of the huggingface tokenizer to use. tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if available, and "slow" will always use the slow tokenizer. @@ -81,8 +81,8 @@ class ModelConfig: skip_tokenizer_init: If true, skip initialization of tokenizer and detokenizer. served_model_name: The model name used in metrics tag `model_name`, - matches the model name exposed via the APIs. If multiple model - names provided, the first name will be used. If not specified, + matches the model name exposed via the APIs. If multiple model + names provided, the first name will be used. If not specified, the model name will be the same as `model`. """ @@ -138,8 +138,14 @@ def __init__( self.disable_sliding_window = disable_sliding_window self.skip_tokenizer_init = skip_tokenizer_init - self.hf_config = get_config(self.model, trust_remote_code, revision, - code_revision, rope_scaling, rope_theta) + self.hf_config = get_config( + self.model, + trust_remote_code, + revision, + code_revision, + rope_scaling, + rope_theta, + ) self.hf_text_config = get_hf_text_config(self.hf_config) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) @@ -157,7 +163,8 @@ def __init__( hf_config=self.hf_text_config, max_model_len=max_model_len, disable_sliding_window=self.disable_sliding_window, - sliding_window_len=self.get_hf_config_sliding_window()) + sliding_window_len=self.get_hf_config_sliding_window(), + ) self.served_model_name = get_served_model_name(model, served_model_name) self.multimodal_config = multimodal_config @@ -224,18 +231,25 @@ def _verify_quantization(self) -> None: raise ValueError( f"Unknown quantization method: {self.quantization}. 
Must " f"be one of {supported_quantization}.") - if is_hip( - ) and self.quantization not in rocm_supported_quantization: + if (is_hip() + and self.quantization not in rocm_supported_quantization): raise ValueError( f"{self.quantization} quantization is currently not " f"supported in ROCm.") - if (self.quantization - not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin", - "gptq_bitblas", "bitblas")): + if self.quantization not in ( + "fp8", + "marlin", + "gptq_marlin_24", + "gptq_marlin", + "gptq_bitblas", + "bitblas", + ): logger.warning( "%s quantization is not fully " "optimized yet. The speed can be slower than " - "non-quantized models.", self.quantization) + "non-quantized models.", + self.quantization, + ) def _verify_cuda_graph(self) -> None: if self.max_seq_len_to_capture is None: @@ -283,8 +297,7 @@ def get_hf_config_sliding_window(self) -> Optional[int]: return getattr(self.hf_text_config, "sliding_window", None) def get_sliding_window(self) -> Optional[int]: - """Get the sliding window size, or None if disabled. - """ + """Get the sliding window size, or None if disabled.""" # If user disables sliding window, return None. if self.disable_sliding_window: return None @@ -299,29 +312,32 @@ def get_hidden_size(self) -> int: def find_flash_attn_supported_head_dims(self, head_dim: int) -> int: """ - Find the closest head dimension to the given head dimension that is supported by Flash Attention. + Find the closest head dimension to the given head dimension that + is supported by Flash Attention. """ from vllm.attention.backends.flash_attn import FlashAttentionBackend - FLASHATTN_SUPPORTED_HEAD_DIMS = FlashAttentionBackend.get_supported_head_sizes( - ) + + FLASHATTN_SUPPORTED_HEAD_DIMS = ( + FlashAttentionBackend.get_supported_head_sizes()) for supported_head_dim in FLASHATTN_SUPPORTED_HEAD_DIMS: if head_dim <= supported_head_dim: return supported_head_dim raise ValueError( - f"Head dimension {head_dim} is not supported by Flash Attention. Supported head dimensions are " - f"{FLASHATTN_SUPPORTED_HEAD_DIMS}.") + f"Head dimension {head_dim} is not supported by Flash Attention." + f"Supported head dimensions are {FLASHATTN_SUPPORTED_HEAD_DIMS}.") def get_head_size(self) -> int: # TODO remove hard code - if hasattr(self.hf_text_config, "model_type" - ) and self.hf_text_config.model_type == 'deepseek_v2': + if (hasattr(self.hf_text_config, "model_type") + and self.hf_text_config.model_type == "deepseek_v2"): # FlashAttention supports only head_size 32, 64, 128, 256, # we need to pad head_size 192 to 256 return 256 - if hasattr( - self.hf_text_config, "architectures" - ) and 'BitnetForCausalLM' in self.hf_text_config.architectures: + if (hasattr(self.hf_text_config, "architectures") + and "BitnetForCausalLM" in self.hf_text_config.architectures): + # FlashAttention does not support head_size 100 + # TODO: implement for head_size 100 return self.find_flash_attn_supported_head_dims( (self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads)) @@ -354,8 +370,11 @@ def get_total_num_kv_heads(self) -> int: return self.hf_config.attn_config["kv_n_heads"] return self.hf_config.num_attention_heads if self.hf_config.model_type == "dbrx": - return getattr(self.hf_config.attn_config, "kv_n_heads", - self.hf_config.num_attention_heads) + return getattr( + self.hf_config.attn_config, + "kv_n_heads", + self.hf_config.num_attention_heads, + ) attributes = [ # For Falcon: @@ -499,6 +518,7 @@ class TokenizerPoolConfig: The way the config will be used depends on the pool type. 
""" + pool_size: int pool_type: str extra_config: dict @@ -511,8 +531,10 @@ def __post_init__(self): @classmethod def create_config( - cls, tokenizer_pool_size: int, tokenizer_pool_type: str, - tokenizer_pool_extra_config: Optional[Union[str, dict]] + cls, + tokenizer_pool_size: int, + tokenizer_pool_type: str, + tokenizer_pool_extra_config: Optional[Union[str, dict]], ) -> Optional["TokenizerPoolConfig"]: """Create a TokenizerPoolConfig from the given parameters. @@ -532,9 +554,11 @@ def create_config( else: tokenizer_pool_extra_config_parsed = ( tokenizer_pool_extra_config or {}) - tokenizer_pool_config = cls(tokenizer_pool_size, - tokenizer_pool_type, - tokenizer_pool_extra_config_parsed) + tokenizer_pool_config = cls( + tokenizer_pool_size, + tokenizer_pool_type, + tokenizer_pool_extra_config_parsed, + ) else: tokenizer_pool_config = None return tokenizer_pool_config @@ -554,20 +578,20 @@ class LoadFormat(str, enum.Enum): @dataclass class LoadConfig: """ - download_dir: Directory to download and load the weights, default to the - default cache directory of huggingface. - load_format: The format of the model weights to load: - "auto" will try to load the weights in the safetensors format and - fall back to the pytorch bin format if safetensors format is - not available. - "pt" will load the weights in the pytorch bin format. - "safetensors" will load the weights in the safetensors format. - "npcache" will load the weights in pytorch format and store - a numpy cache to speed up the loading. - "dummy" will initialize the weights with random values, which is - mainly for profiling. - "tensorizer" will use CoreWeave's tensorizer library for - fast weight loading. + download_dir: Directory to download and load the weights, default to the + default cache directory of huggingface. + load_format: The format of the model weights to load: + "auto" will try to load the weights in the safetensors format and + fall back to the pytorch bin format if safetensors format is + not available. + "pt" will load the weights in the pytorch bin format. + "safetensors" will load the weights in the safetensors format. + "npcache" will load the weights in pytorch format and store + a numpy cache to speed up the loading. + "dummy" will initialize the weights with random values, which is + mainly for profiling. + "tensorizer" will use CoreWeave's tensorizer library for + fast weight loading. """ load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO @@ -659,6 +683,7 @@ def __init__( # current node and we aren't in a ray placement group. from vllm.executor import ray_utils + backend = "mp" ray_found = ray_utils.ray is not None if cuda_device_count_stateless() < self.world_size: @@ -671,8 +696,10 @@ def __init__( backend = "ray" else: from ray import is_initialized as ray_is_initialized + if ray_is_initialized(): from ray.util import get_current_placement_group + if get_current_placement_group(): backend = "ray" self.distributed_executor_backend = backend @@ -732,7 +759,7 @@ class SchedulerConfig: enable_chunked_prefill: If True, prefill requests can be chunked based on the remaining max_num_batched_tokens. embedding_mode: Whether the running model is for embedding. - preemption_mode: Whether to perform preemption by swapping or + preemption_mode: Whether to perform preemption by swapping or recomputation. If not specified, we determine the mode as follows: We use recomputation by default since it incurs lower overhead than swapping. 
However, when the sequence group has multiple sequences @@ -740,16 +767,18 @@ class SchedulerConfig: such a case, we use swapping instead. """ - def __init__(self, - max_num_batched_tokens: Optional[int], - max_num_seqs: int, - max_model_len: int, - use_v2_block_manager: bool = False, - num_lookahead_slots: int = 0, - delay_factor: float = 0.0, - enable_chunked_prefill: bool = False, - embedding_mode: Optional[bool] = False, - preemption_mode: Optional[str] = None) -> None: + def __init__( + self, + max_num_batched_tokens: Optional[int], + max_num_seqs: int, + max_model_len: int, + use_v2_block_manager: bool = False, + num_lookahead_slots: int = 0, + delay_factor: float = 0.0, + enable_chunked_prefill: bool = False, + embedding_mode: Optional[bool] = False, + preemption_mode: Optional[str] = None, + ) -> None: if max_num_batched_tokens is not None: self.max_num_batched_tokens = max_num_batched_tokens else: @@ -973,8 +1002,8 @@ def maybe_create_spec_config( "Speculative decoding with mlp_speculator models does not " "yet support distributed inferencing (TP > 1).") - if (num_speculative_tokens is not None - and hasattr(draft_hf_config, "num_lookahead_tokens")): + if num_speculative_tokens is not None and hasattr( + draft_hf_config, "num_lookahead_tokens"): draft_hf_config.num_lookahead_tokens = num_speculative_tokens n_predict = getattr(draft_hf_config, "n_predict", None) @@ -1000,7 +1029,8 @@ def maybe_create_spec_config( draft_parallel_config = ( SpeculativeConfig.create_draft_parallel_config( target_parallel_config, - speculative_draft_tensor_parallel_size)) + speculative_draft_tensor_parallel_size, + )) if num_speculative_tokens is None: raise ValueError( @@ -1055,15 +1085,15 @@ def _maybe_override_draft_max_model_len( @staticmethod def create_draft_parallel_config( target_parallel_config: ParallelConfig, - speculative_draft_tensor_parallel_size: Optional[int] + speculative_draft_tensor_parallel_size: Optional[int], ) -> ParallelConfig: """Create a parallel config for use by the draft worker. This is mostly a copy of the target parallel config, except the tp_size. 
""" if speculative_draft_tensor_parallel_size is None: - speculative_draft_tensor_parallel_size = \ - target_parallel_config.tensor_parallel_size + speculative_draft_tensor_parallel_size = ( + target_parallel_config.tensor_parallel_size) elif speculative_draft_tensor_parallel_size != 1: # TODO(wooyeon): allow tp values larger than 1 raise ValueError( @@ -1113,8 +1143,8 @@ def __init__( self.draft_model_config = draft_model_config self.draft_parallel_config = draft_parallel_config self.num_speculative_tokens = num_speculative_tokens - self.speculative_disable_by_batch_size = \ - speculative_disable_by_batch_size + self.speculative_disable_by_batch_size = ( + speculative_disable_by_batch_size) self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0 self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0 @@ -1187,11 +1217,14 @@ def verify_with_model_config(self, model_config: ModelConfig): elif isinstance(self.lora_dtype, str): self.lora_dtype = getattr(torch, self.lora_dtype) if model_config.quantization and model_config.quantization not in [ - "awq", "gptq" + "awq", + "gptq", ]: # TODO support marlin and squeezellm - logger.warning("%s quantization is not tested with LoRA yet.", - model_config.quantization) + logger.warning( + "%s quantization is not tested with LoRA yet.", + model_config.quantization, + ) def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): if scheduler_config.max_num_batched_tokens > 65528: @@ -1221,6 +1254,7 @@ class ImageInputType(enum.Enum): For example, for Llava, PIXEL_VALUES: (1, 3, 336, 336). IMAGE_FEATURES: (1, 576, 1024). """ + PIXEL_VALUES = enum.auto() IMAGE_FEATURES = enum.auto() @@ -1246,11 +1280,11 @@ def get_image_input_enum_type(cls, value: str) -> ImageInputType: f"Expecting to choose from " f"{[x.name for x in cls.ImageInputType]}.") from e - #TODO(ywang96): make this a cached property once we refactor the + # TODO(ywang96): make this a cached property once we refactor the # VisionLanguageConfig class. def get_image_token_text( self, tokenizer: PreTrainedTokenizerBase) -> Tuple[str, str]: - """Get the image token placeholder text to be inserted into the + """Get the image token placeholder text to be inserted into the text prompt and the string representation of the image token id. """ image_token_str = tokenizer.decode(self.image_token_id) @@ -1368,15 +1402,16 @@ def _get_and_verify_max_len( for key in possible_keys: max_len = getattr(hf_config, key, None) if max_len is not None: - max_len_key = key if max_len < derived_max_model_len \ - else max_len_key + max_len_key = (key + if max_len < derived_max_model_len else max_len_key) derived_max_model_len = min(derived_max_model_len, max_len) # If sliding window is manually disabled, max_length should be less # than the sliding window length in the model config. if disable_sliding_window and sliding_window_len is not None: - max_len_key = "sliding_window" \ - if sliding_window_len < derived_max_model_len else max_len_key + max_len_key = ("sliding_window" + if sliding_window_len < derived_max_model_len else + max_len_key) derived_max_model_len = min(derived_max_model_len, sliding_window_len) # If none of the keys were found in the config, use a default and @@ -1390,15 +1425,17 @@ def _get_and_verify_max_len( logger.warning( "The model's config.json does not contain any of the following " "keys to determine the original maximum length of the model: " - "%s. Assuming the model's maximum length is %d.", possible_keys, - default_max_len) + "%s. 
Assuming the model's maximum length is %d.", + possible_keys, + default_max_len, + ) derived_max_model_len = default_max_len rope_scaling = getattr(hf_config, "rope_scaling", None) # The correct one should be "longrope", kept "su" here # to be backward compatible - if rope_scaling is not None and rope_scaling["type"] != "su" \ - and rope_scaling["type"] != "longrope": + if (rope_scaling is not None and rope_scaling["type"] != "su" + and rope_scaling["type"] != "longrope"): if disable_sliding_window: # TODO(robertgshaw): Find a model that supports rope_scaling # with sliding window to see if this case should be allowed. @@ -1445,10 +1482,10 @@ def _get_and_verify_max_len( def get_served_model_name(model: str, served_model_name: Optional[Union[str, List[str]]]): """ - If the input is a non-empty list, the first model_name in - `served_model_name` is taken. - If the input is a non-empty string, it is used directly. - For cases where the input is either an empty string or an + If the input is a non-empty list, the first model_name in + `served_model_name` is taken. + If the input is a non-empty string, it is used directly. + For cases where the input is either an empty string or an empty list, the fallback is to use `self.model`. """ if not served_model_name: @@ -1463,10 +1500,10 @@ class DecodingConfig: """Dataclass which contains the decoding strategy of the engine""" # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer' - guided_decoding_backend: str = 'outlines' + guided_decoding_backend: str = "outlines" def __post_init__(self): - valid_guided_backends = ['outlines', 'lm-format-enforcer'] + valid_guided_backends = ["outlines", "lm-format-enforcer"] backend = self.guided_decoding_backend if backend not in valid_guided_backends: raise ValueError(f"Invalid guided_decoding_backend '{backend}," @@ -1476,6 +1513,7 @@ def __post_init__(self): @dataclass class ObservabilityConfig: """Configuration for observability.""" + otlp_traces_endpoint: Optional[str] = None def __post_init__(self): @@ -1503,8 +1541,7 @@ class EngineConfig: observability_config: Optional[ObservabilityConfig] def __post_init__(self): - """Verify configs are valid & consistent with each other. - """ + """Verify configs are valid & consistent with each other.""" self.model_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config) @@ -1514,7 +1551,6 @@ def __post_init__(self): self.scheduler_config) def to_dict(self): - """Return the configs as a dictionary, for use in **kwargs. 
- """ + """Return the configs as a dictionary, for use in **kwargs.""" return dict( (field.name, getattr(self, field.name)) for field in fields(self)) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index c567ad9d649f..db9a168daa86 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -314,8 +314,11 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): if len(loaded_weight.shape) == 0: loaded_weight = loaded_weight.reshape(1) - assert param_data.dtype == loaded_weight.dtype, f"{param.data.dtype} != {loaded_weight.dtype}" - assert param_data.shape == loaded_weight.shape, f"{param_data.shape} != {loaded_weight.shape}" + assert param_data.dtype == loaded_weight.dtype, ( + f"{param_data.dtype} != {loaded_weight.dtype}") + assert param_data.shape == loaded_weight.shape, ( + f"{param_data.shape} != {loaded_weight.shape}") + param_data.copy_(loaded_weight) def forward(self, input_): diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 4585ae98bb0b..659e9548a67a 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -4,23 +4,24 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.quantization.bitsandbytes import ( - BitsAndBytesConfig) +from vllm.model_executor.layers.quantization.bitblas import BitBLASConfig from vllm.model_executor.layers.quantization.bitnet_bitblas import ( BITNETBitBLASConfig) +from vllm.model_executor.layers.quantization.bitsandbytes import ( + BitsAndBytesConfig) from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 CompressedTensorsConfig) from vllm.model_executor.layers.quantization.deepspeedfp import ( DeepSpeedFPConfig) from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.quantization.gptq import GPTQConfig +from vllm.model_executor.layers.quantization.gptq_bitblas import ( + GPTQBitBLASConfig) from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinConfig) from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQMarlin24Config) -from vllm.model_executor.layers.quantization.gptq_bitblas import GPTQBitBLASConfig from vllm.model_executor.layers.quantization.marlin import MarlinConfig -from vllm.model_executor.layers.quantization.bitblas import BitBLASConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { diff --git a/vllm/model_executor/layers/quantization/bitblas.py b/vllm/model_executor/layers/quantization/bitblas.py index 9cad817608aa..f1475818ab0a 100644 --- a/vllm/model_executor/layers/quantization/bitblas.py +++ b/vllm/model_executor/layers/quantization/bitblas.py @@ -3,7 +3,6 @@ import torch from torch.nn.parameter import Parameter -from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( @@ -14,14 +13,13 @@ try: import bitblas + from bitblas.utils import auto_detect_nvidia_target except ImportError as e: bitblas_import_exception = e raise ValueError( - f"Trying to use the bitblas backend, but could not import 
dependencies with the following error: {bitblas_import_exception}" - ) - -import bitblas -from bitblas.utils import auto_detect_nvidia_target + "Trying to use the bitblas backend, but could not import dependencies" + f"with the following error: {bitblas_import_exception}" + ) from bitblas_import_exception BITBLAS_TARGET = auto_detect_nvidia_target() BITBLAS_DATABASE_PATH = bitblas.cache.get_database_path() @@ -38,7 +36,9 @@ class BitBLASConfig(QuantizationConfig): TORCH_DTYPE = torch.float16 STORAGE_DTYPE = "int8" # assume int8 storage TORCH_STORAGE_DTYPE = getattr(torch, STORAGE_DTYPE) - ZEROS_MODE = "quantized" # "original" or "rescale" or "quantized", the gptq_bitblas prefer "quantized" + # "original" or "rescale" or "quantized", + # gptq_with_bitblas prefer "quantized implementation" + ZEROS_MODE = "quantized" def __init__(self, weight_bits: int, group_size: int, desc_act: bool, is_sym: bool) -> None: @@ -165,7 +165,8 @@ def create_weights( ): """Creates quantized weights for use in linear operations. - The function initializes and returns a dictionary containing quantized weights, scales, and zeros + The function initializes and returns a dictionary containing quantized + weights, scales, and zeros for performing quantized matrix multiplication operations. Args: @@ -173,28 +174,31 @@ def create_weights( output_size_per_partition: The size of the output partition. input_size: The total size of the input (unused). output_size: The total size of the output (unused). - params_dtype: The data type of the parameters (expected to be torch.float16). + params_dtype: + The data type of the parameters (expected to be torch.float16). Returns: - A dictionary containing the quantized weights ('qweight'), scales ('scales'), and zeros ('zeros'). + A dictionary containing the quantized weights ('qweight'), + scales ('scales'), and zeros ('zeros'). Raises: - ValueError: If `params_dtype` is not `torch.float16` or if the input size per partition - is not divisible by the group size in `quant_config`. + ValueError: If `params_dtype` is not `torch.float16` or if the + input size per partition is not divisible by the group size in + `quant_config`. """ del input_size, output_size # Unused arguments. if params_dtype != torch.float16: - raise ValueError( - f"Parameter data type must be torch.float16, but got {params_dtype}" - ) + raise ValueError("Parameter data type must be torch.float16, " + f"but got {params_dtype}") # Validate output_size_per_partition output_size_per_partition = sum(output_partition_sizes) if (self.quant_config.group_size != -1 and input_size_per_partition % self.quant_config.group_size != 0): raise ValueError( - f"Input size per partition ({input_size_per_partition}) must be divisible by " - f"group size ({self.quant_config.group_size}).") + f"Input size per partition ({input_size_per_partition}) must " + f"be divisible by group size ({self.quant_config.group_size})." + ) # Initialize or retrieve the BitBLAS matrix multiplication operator. self._configure_bitblas_matmul( @@ -206,7 +210,8 @@ def create_weights( layout="nt", bits=self.quant_config.weight_bits, ) - # Initialize quantized weights with dimensions optimized for BitBLAS operations. 
+ + # Initialize quantized weights with dimensions qweight = Parameter( torch.empty( @@ -348,14 +353,15 @@ def _get_or_create_bitblas_operator(self, config, enable_tuning): global_operator_cache.add(config, bitblas_matmul) global_operator_cache.save_into_database( BITBLAS_DATABASE_PATH, BITBLAS_TARGET) - logger.info( - "BitBLAS Tuning done, appended operator to global_operator_cache." - ) + logger.info("BitBLAS Tuning done, appended operator to " + "global_operator_cache.") else: - logger.info(f"BitBLAS Operator {config} created.") + _message = f"BitBLAS Operator {config} created." + logger.info(_message) else: - logger.info( + _message = ( f"BitBLAS Operator {config} found in global_operator_cache.") + logger.info(_message) return bitblas_matmul def apply( diff --git a/vllm/model_executor/layers/quantization/bitnet_bitblas.py b/vllm/model_executor/layers/quantization/bitnet_bitblas.py index 03d7b3ea862d..0423fcefc904 100644 --- a/vllm/model_executor/layers/quantization/bitnet_bitblas.py +++ b/vllm/model_executor/layers/quantization/bitnet_bitblas.py @@ -2,30 +2,27 @@ from enum import Enum from typing import Any, Dict, List, Optional +import bitblas.cache import torch from torch.nn.parameter import Parameter -import bitblas.cache from vllm.logger import init_logger -from vllm.model_executor.layers.linear import ( - LinearBase, - LinearMethodBase, - set_weight_attrs, -) -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 + QuantizationConfig) logger = init_logger(__name__) try: import bitblas + from bitblas.utils import auto_detect_nvidia_target except ImportError as e: bitblas_import_exception = e - raise ValueError( - f"Trying to use the bitblas backend, but could not import dependencies with the following error: {bitblas_import_exception}" - ) - -import bitblas -from bitblas.utils import auto_detect_nvidia_target + error_message = ( + "Trying to use the bitblas backend, but could not import dependencies " + f"with the following error: {bitblas_import_exception}") + raise ValueError(error_message) from bitblas_import_exception bitblas.set_log_level("Debug") BITBLAS_TARGET = auto_detect_nvidia_target() @@ -60,9 +57,8 @@ def __init__(self, weight_bits: int, is_sym: bool) -> None: self.nbits = weight_bits def __repr__(self) -> str: - return ( - f"BITNETBitBLASConfig(weight_bits={self.weight_bits}, is_sym={self.is_sym})" - ) + return (f"BITNETBitBLASConfig(weight_bits={self.weight_bits}, " + f"is_sym={self.is_sym})") @classmethod def get_name(cls) -> str: @@ -173,28 +169,31 @@ def create_weights( ) -> None: """Creates quantized weights for use in linear operations. - The function initializes and returns a dictionary containing quantized weights, scales, and zeros - for performing quantized matrix multiplication operations. + The function initializes and returns a dictionary containing + quantized weights, scales, and zeros for performing quantized + matrix multiplication operations. Args: input_size_per_partition: The size of the input partition. output_partition_sizes: The size of the output partition. input_size: The total size of the input (unused). output_size: The total size of the output (unused). - params_dtype: The data type of the parameters (expected to be torch.float16). + params_dtype: + The data type of the parameters (expected to be torch.float16). 
Returns: - A dictionary containing the quantized weights ('qweight'), scales ('scales'), and zeros ('zeros'). + A dictionary containing the quantized weights ('qweight'), + scales ('scales'), and zeros ('zeros'). Raises: - ValueError: If `params_dtype` is not `torch.float16` or if the input size per partition - is not divisible by the group size in `quant_config`. + ValueError: If `params_dtype` is not `torch.float16` or if the + input size per partition is not divisible by the group size + in `quant_config`. """ del output_size # Unused arguments. if params_dtype != torch.float16: - raise ValueError( - f"Parameter data type must be torch.float16, but got {params_dtype}" - ) + raise ValueError("Parameter data type must be torch.float16, " + f"but got {params_dtype}") # Validate output_size_per_partition output_size_per_partition = sum(output_partition_sizes) @@ -300,14 +299,15 @@ def _get_or_create_bitblas_operator(self, config, enable_tuning): global_operator_cache.add(config, bitblas_matmul) global_operator_cache.save_into_database( BITBLAS_DATABASE_PATH, BITBLAS_TARGET) - logger.info( - "BitBLAS Tuning done, appended operator to global_operator_cache." - ) + logger.info("BitBLAS Tuning done, appended operator to " + "global_operator_cache.") else: - logger.info(f"BitBLAS Operator {config} created.") + _message = ( + f"BitBLAS Operator {config} created without tuning. ") + logger.info(_message) else: - logger.info( - f"BitBLAS Operator {config} found in global_operator_cache.") + _message = (f"BitBLAS Operator {config} retrieved from cache.") + logger.info(_message) return bitblas_matmul def weight_quant(self, weight): @@ -374,13 +374,13 @@ def free_tensor(name): delattr(layer, name) def replace_tensor(name, new_t): - # Cannot use copy_() because the storage shape and dtype are different - # del layer._parameters[name] + # Cannot use copy_() as gptq because the storage + # shape and dtype are different delattr(layer, name) setattr(layer, name, new_t) # Repack weights - bitblas_qweight = self.repack_bitblas_from_bitnet(layer.weight, ) + bitblas_qweight = self.repack_bitblas_from_bitnet(layer.weights) # free the original weight tensor free_tensor("weight") replace_tensor("qweight", bitblas_qweight) diff --git a/vllm/model_executor/layers/quantization/gptq_bitblas.py b/vllm/model_executor/layers/quantization/gptq_bitblas.py index 14df921e710a..a54f839a7e52 100644 --- a/vllm/model_executor/layers/quantization/gptq_bitblas.py +++ b/vllm/model_executor/layers/quantization/gptq_bitblas.py @@ -2,30 +2,27 @@ from enum import Enum from typing import Any, Dict, List, Optional +import bitblas.cache import torch from torch.nn.parameter import Parameter -import bitblas.cache from vllm.logger import init_logger -from vllm.model_executor.layers.linear import ( - LinearBase, - LinearMethodBase, - set_weight_attrs, -) -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) logger = init_logger(__name__) try: import bitblas + from bitblas.utils import auto_detect_nvidia_target except ImportError as e: bitblas_import_exception = e raise ValueError( - f"Trying to use the bitblas backend, but could not import dependencies with the following error: {bitblas_import_exception}" - ) - -import bitblas -from bitblas.utils import auto_detect_nvidia_target + "Trying to use the bitblas backend, but could 
not import dependencies" + f"with the following error: {bitblas_import_exception}" + ) from bitblas_import_exception bitblas.set_log_level("Debug") BITBLAS_TARGET = auto_detect_nvidia_target() @@ -62,7 +59,9 @@ class GPTQBitBLASConfig(QuantizationConfig): ) GPTQ_BITBLAS_STORAGE_DTYPE = "int8" # BitBLAS uses int8 as storage dtype TORCH_BITBLAS_STORAGE_DTYPE = getattr(torch, GPTQ_BITBLAS_STORAGE_DTYPE) - ZEROS_MODE = "quantized" # "original" or "rescale" or "quantized", the gptq_bitblas prefer "quantized" + # "original" or "rescale" or "quantized", + # the gptq_bitblas prefer "quantized" + ZEROS_MODE = "quantized" def __init__(self, weight_bits: int, group_size: int, desc_act: bool, is_sym: bool) -> None: @@ -168,7 +167,8 @@ def is_bitblas_compatible(cls, quant_config: Dict[str, Any]): desc_act = quant_config.get("desc_act", None) # If we cannot find the info needed in the config, cannot convert. - if num_bits is None or group_size is None or sym is None or desc_act is None: + if (num_bits is None or group_size is None or sym is None + or desc_act is None): return False # If the capability of the device is too low, cannot convert. @@ -218,7 +218,8 @@ def create_weights( ) -> None: """Creates quantized weights for use in linear operations. - The function initializes and returns a dictionary containing quantized weights, scales, and zeros + The function initializes and returns a dictionary containing + quantized weights, scales, and zeros for performing quantized matrix multiplication operations. Args: @@ -226,20 +227,22 @@ def create_weights( output_partition_sizes: The size of the output partition. input_size: The total size of the input (unused). output_size: The total size of the output (unused). - params_dtype: The data type of the parameters (expected to be torch.float16). + params_dtype: + The data type of the parameters (expected to be torch.float16). Returns: - A dictionary containing the quantized weights ('qweight'), scales ('scales'), and zeros ('zeros'). + A dictionary containing the quantized weights ('qweight'), + scales ('scales'), and zeros ('zeros'). Raises: - ValueError: If `params_dtype` is not `torch.float16` or if the input size per partition - is not divisible by the group size in `quant_config`. + ValueError: If `params_dtype` is not `torch.float16` or + if the input size per partition is not divisible by the + group size in `quant_config`. """ del output_size # Unused arguments. if params_dtype != torch.float16: - raise ValueError( - f"Parameter data type must be torch.float16, but got {params_dtype}" - ) + raise ValueError("Parameter data type must be torch.float16, " + f"but got {params_dtype}") # Normalize group_size if self.quant_config.group_size != -1: @@ -249,8 +252,9 @@ def create_weights( if input_size_per_partition % group_size != 0: raise ValueError( - f"Input size per partition ({input_size_per_partition}) must be divisible by " - f"group size ({self.quant_config.group_size}).") + f"Input size per partition ({input_size_per_partition}) must " + f"be divisible by group size ({self.quant_config.group_size})." + ) # Validate output_size_per_partition output_size_per_partition = sum(output_partition_sizes) @@ -411,27 +415,29 @@ def _get_or_create_bitblas_operator(self, config, enable_tuning): global_operator_cache.add(config, bitblas_matmul) global_operator_cache.save_into_database( BITBLAS_DATABASE_PATH, BITBLAS_TARGET) - logger.info( - "BitBLAS Tuning done, appended operator to global_operator_cache." 
- ) + logger.info("BitBLAS Tuning done, appended operator to " + "global_operator_cache.") else: - logger.info(f"BitBLAS Operator {config} created.") + _message = f"BitBLAS Operator {config} created without tuning. " + logger.info(_message) else: - logger.info( - f"BitBLAS Operator {config} found in global_operator_cache.") + _message = f"BitBLAS Operator {config} retrieved from cache." + logger.info(_message) return bitblas_matmul def repack_bitblas_from_gptq(self, b_q_weight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor): from bitblas.quantization.utils import general_compress - # qweight in gptq old quant linear stored with (outfeatures, infeatures), should be transposed. + # qweight in gptq old quant linear stored with + # (outfeatures, infeatures), should be transposed. qweight = b_q_weight.T.contiguous().view( self.quant_config.TORCH_BITBLAS_STORAGE_DTYPE) if self.bitblas_matmul.weight_transform is not None: qweight = self.bitblas_matmul.weight_transform( qweight.cpu()).cuda() - # scales in gptq old quant linear stored with (infeatures // group_size, outfeatures), should be transposed. + # scales in gptq old quant linear stored with + # (infeatures // group_size, outfeatures), should be transposed. scales = scales.T.contiguous() # qzeros should be de-quantized to int zeros. intzeros = unpack_qzeros(qzeros, @@ -451,9 +457,8 @@ def repack_bitblas_from_gptq(self, b_q_weight: torch.Tensor, self.quant_config.TORCH_BITBLAS_STORAGE_DTYPE).contiguous( )) else: - raise ValueError( - f"Unsupported zeros type: {self.bitblas_matmul.config.zeros_mode}" - ) + raise ValueError("Unsupported zeros type: {}".format( + self.bitblas_matmul.config.zeros_mode)) return qweight, scales, zeros diff --git a/vllm/model_executor/models/bitnet.py b/vllm/model_executor/models/bitnet.py index dbb4ddb2e3ac..2c109b47a0ea 100644 --- a/vllm/model_executor/models/bitnet.py +++ b/vllm/model_executor/models/bitnet.py @@ -21,38 +21,35 @@ # See the License for the specific language governing permissions and # limitations under the License. 
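The GPTQ repack path above first has to unpack the checkpoint's bit-packed zero points before re-compressing them into the layout the BitBLAS kernel expects. A standalone sketch of that unpacking step for 4-bit checkpoints, assuming the common AutoGPTQ layout in which each int32 of `qzeros` packs eight zero points starting from the least significant nibble (offset conventions differ between exporters, so treat this as illustrative rather than a drop-in replacement for the helper used in the patch):

    import torch

    def unpack_qzeros_4bit(qzeros: torch.Tensor) -> torch.Tensor:
        # qzeros: (num_groups, out_features // 8) int32
        # returns: (num_groups, out_features) int32 zero points in [0, 15]
        shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=qzeros.device)
        unpacked = (qzeros.unsqueeze(-1) >> shifts) & 0xF
        return unpacked.reshape(qzeros.shape[0], -1)

    packed = torch.tensor([[0x76543210]], dtype=torch.int32)
    print(unpack_qzeros_4bit(packed))  # tensor([[0, 1, 2, 3, 4, 5, 6, 7]])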
"""Inference-only Bitnet model compatible with HuggingFace weights.""" + +# ruff: noqa: E501 + from typing import Dict, Iterable, List, Optional, Tuple + import torch from torch import nn from transformers.configuration_utils import PretrainedConfig + from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import ( - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, -) +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.linear import ( - MergedColumnParallelLinear, - RowParallelLinear, -) +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, - ParallelLMHead, - VocabParallelEmbedding, -) + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, - kv_cache_scales_loader, -) + default_weight_loader, kv_cache_scales_loader) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput from vllm.utils import is_hip, print_warning_once -from vllm.logger import init_logger logger = init_logger(__name__) @@ -435,8 +432,8 @@ def find_flash_attn_supported_head_dims(self, head_dim: int) -> int: """ from vllm.attention.backends.flash_attn import FlashAttentionBackend - FLASHATTN_SUPPORTED_HEAD_DIMS = FlashAttentionBackend.get_supported_head_sizes( - ) + FLASHATTN_SUPPORTED_HEAD_DIMS = ( + FlashAttentionBackend.get_supported_head_sizes()) for supported_head_dim in FLASHATTN_SUPPORTED_HEAD_DIMS: if head_dim <= supported_head_dim: return supported_head_dim @@ -451,7 +448,8 @@ def forward( kv_cache: Optional[torch.Tensor], attn_metadata: AttentionMetadata, ) -> torch.Tensor: - # QKV projection cannot be grouped as the they do not share the same scaling factor + # QKV projection cannot be grouped as the they + # do not share the same scaling factor q, _ = self.q_proj(hidden_states) k, _ = self.k_proj(hidden_states) v, _ = self.v_proj(hidden_states) @@ -704,7 +702,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index d8ca38e6191b..9372bf588b6f 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -99,7 +99,8 @@ def get_tokenizer( if "BitnetTokenizer" in str(e): # This is for the error "'BitnetTokenizer' object has no # attribute 'sp_model'". 
- from vllm.transformers_utils.tokenizers.bitnet import BitnetTokenizer + from vllm.transformers_utils.tokenizers.bitnet import ( + BitnetTokenizer) tokenizer = BitnetTokenizer.from_pretrained( tokenizer_name, *args, @@ -109,8 +110,9 @@ def get_tokenizer( elif (not trust_remote_code and ("does not exist or is not currently imported." in str(e) or "requires you to execute the tokenizer file" in str(e))): - # If the error pertains to the tokenizer class not existing or not - # currently being imported, suggest using the --trust-remote-code flag. + # If the error pertains to the tokenizer class not existing + # or not currently being imported, suggest using the + # --trust-remote-code flag. err_msg = ( "Failed to load the tokenizer. If the tokenizer is a custom " "tokenizer not yet available in the HuggingFace transformers " diff --git a/vllm/transformers_utils/tokenizers/bitnet.py b/vllm/transformers_utils/tokenizers/bitnet.py index 32a9f22c457f..02caf88c5f3b 100644 --- a/vllm/transformers_utils/tokenizers/bitnet.py +++ b/vllm/transformers_utils/tokenizers/bitnet.py @@ -1,12 +1,13 @@ # Adapted from # https://github.com/microsoft/BitBLAS/blob/main/integration/BitNet/tokenization_bitnet.py + +# ruff: noqa: E501 """Tokenization classes for Bitnet.""" import os from shutil import copyfile from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import sentencepiece as spm - from transformers.convert_slow_tokenizer import import_protobuf from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer from transformers.utils import logging @@ -321,8 +322,8 @@ def save_vocabulary( `Tuple(str)`: Paths to the files saved. """ if not os.path.isdir(save_directory): - logger.error( - f"Vocabulary path ({save_directory}) should be a directory") + error_message = f"Vocabulary path ({save_directory}) should be a directory" + logger.error(error_message) return None out_vocab_file = os.path.join( save_directory, From 29ac34d1167b259c8c919854c494ab35b2d62520 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Wed, 17 Jul 2024 16:53:38 +0000 Subject: [PATCH 07/24] Support Loading 1.58B Model with BitBLAS Format --- vllm/model_executor/layers/linear.py | 8 +- .../layers/quantization/bitblas.py | 271 ++++++++++++++++-- .../layers/quantization/bitnet_bitblas.py | 8 +- vllm/model_executor/models/bitnet.py | 6 + 4 files changed, 260 insertions(+), 33 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index db9a168daa86..ff57a4966f1a 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -28,10 +28,12 @@ def adjust_marlin_shard(param, shard_size, shard_offset): def adjust_bitblas_shard(param, shard_size, shard_offset): bitblas_tile_size = getattr(param, "bitblas_tile_size", None) - if bitblas_tile_size is None: - return shard_size, shard_offset + weight_propagation = getattr(param, "weight_propagation", None) + if weight_propagation and bitblas_tile_size is not None: + return (shard_size // bitblas_tile_size, + shard_offset // bitblas_tile_size) - return shard_size // bitblas_tile_size, shard_offset // bitblas_tile_size + return shard_size, shard_offset def adjust_bitsandbytes_shard(param: Parameter, diff --git a/vllm/model_executor/layers/quantization/bitblas.py b/vllm/model_executor/layers/quantization/bitblas.py index f1475818ab0a..8af5f82f891d 100644 --- a/vllm/model_executor/layers/quantization/bitblas.py +++ b/vllm/model_executor/layers/quantization/bitblas.py @@ -40,8 +40,9 @@ class 
BitBLASConfig(QuantizationConfig): # gptq_with_bitblas prefer "quantized implementation" ZEROS_MODE = "quantized" - def __init__(self, weight_bits: int, group_size: int, desc_act: bool, - is_sym: bool) -> None: + def __init__(self, weight_bits: int, group_size: Optional[int], + desc_act: Optional[bool], is_sym: Optional[bool], + quant_method: Optional[str]) -> None: if desc_act and group_size == -1: # In this case, act_order == True is the same as act_order == False # (since we have only one group per output channel) @@ -51,6 +52,7 @@ def __init__(self, weight_bits: int, group_size: int, desc_act: bool, self.group_size = group_size self.desc_act = desc_act self.is_sym = is_sym + self.quant_method = quant_method # Verify if self.weight_bits not in BITBLAS_SUPPORTED_NUM_BITS: @@ -75,11 +77,17 @@ def __init__(self, weight_bits: int, group_size: int, desc_act: bool, # Zeros type for the quantized weights. self.zeros_mode = self.ZEROS_MODE + # set input bits if bitnet + self.input_bits: Optional[int] = None + if self.quant_method == "bitnet": + self.input_bits = 8 def __repr__(self) -> str: return (f"BitBLASConfig(weight_bits={self.weight_bits}, " f"group_size={self.group_size}, " - f"desc_act={self.desc_act})") + f"desc_act={self.desc_act}, " + f"is_sym={self.is_sym}, " + f"quant_method={self.quant_method})") @classmethod def get_name(cls) -> str: @@ -98,13 +106,24 @@ def get_min_capability(cls) -> int: def get_config_filenames(cls) -> List[str]: return ["quantize_config.json"] + @staticmethod + def get_from_keys(config: Dict[str, Any], + keys: List[str], + default: Any = None) -> Any: + """Get a value from the model's quantization config.""" + for key in keys: + if key in config: + return config[key] + return default + @classmethod def from_config(cls, config: Dict[str, Any]) -> "BitBLASConfig": weight_bits = cls.get_from_keys(config, ["bits"]) - group_size = cls.get_from_keys(config, ["group_size"]) - desc_act = cls.get_from_keys(config, ["desc_act"]) - is_sym = cls.get_from_keys(config, ["sym"]) - return cls(weight_bits, group_size, desc_act, is_sym) + group_size = cls.get_from_keys(config, ["group_size"], -1) + desc_act = cls.get_from_keys(config, ["desc_act"], False) + is_sym = cls.get_from_keys(config, ["sym"], False) + quant_method = cls.get_from_keys(config, ["quant_method"]) + return cls(weight_bits, group_size, desc_act, is_sym, quant_method) @classmethod def override_quantization_method(cls, hf_quant_cfg, @@ -141,7 +160,7 @@ class BitBLASLinearMethod(LinearMethodBase): Args: quant_config: The BitBLAS quantization config. 
""" - OPT_FEATURES = [1, 16, 32, 64, 128, 256, 512] + OPT_FEATURES = [1, 16, 32, 64, 128, 256, 512, 1024] ENABLE_TUNING = True BITBLAS_DTYPES = { torch.float32: "float32", @@ -152,8 +171,17 @@ class BitBLASLinearMethod(LinearMethodBase): def __init__(self, quant_config: BitBLASConfig): self.quant_config = quant_config + if self.quant_config.quant_method == "bitnet": + input_bits = self.quant_config.input_bits + if input_bits is None: + raise ValueError("input_bits must be set for bitnet") + self.Qp = 2**(input_bits - 1) - 1 + self.Qn = -2**(input_bits - 1) + else: + self.Qp = None + self.Qn = None - def create_weights( + def create_weights_bitnet( self, layer: torch.nn.Module, input_size_per_partition: int, @@ -193,12 +221,114 @@ def create_weights( # Validate output_size_per_partition output_size_per_partition = sum(output_partition_sizes) - if (self.quant_config.group_size != -1 and - input_size_per_partition % self.quant_config.group_size != 0): + + # Initialize or retrieve the BitBLAS matrix multiplication operator. + self._configure_bitblas_matmul(input_size_per_partition, + output_size_per_partition, + params_dtype=torch.int8, + enable_tuning=self.ENABLE_TUNING, + bias=False, + layout="nt", + bits=self.quant_config.weight_bits, + out_dtype="float32") + + # Initialize quantized weights with dimensions + + qweight = Parameter( + torch.empty( + self.bitblas_matmul.retrieve_weight_shape(), + device="cuda", + dtype=self.quant_config.storage_torch_dtype, + ), + requires_grad=False, + ) + # Attributes to help with unpacking and applying the weights later. + set_weight_attrs( + qweight, + { + "input_dim": + 1, + "output_dim": + 0, + "packed_dim": + 1, + "bitblas_tile_size": + (self.bitblas_matmul.retrieve_weight_shape()[-2] + if self.bitblas_matmul.propagate_b else None), + "pack_factor": + self.quant_config.pack_factor, + "weight_propagation": + self.bitblas_matmul.propagate_b, + }, + ) + + sw = Parameter( + torch.empty( + 1, + device="cuda", + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + sw, + { + "input_dim": None, + "output_dim": None, + "ignore_warning": True, + }, + ) + layer.register_parameter("qweight", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("sw", sw) + set_weight_attrs(sw, extra_weight_attrs) + + def create_weights_gptq( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + """Creates quantized weights for use in linear operations. + + The function initializes and returns a dictionary containing quantized + weights, scales, and zeros + for performing quantized matrix multiplication operations. + + Args: + input_size_per_partition: The size of the input partition. + output_size_per_partition: The size of the output partition. + input_size: The total size of the input (unused). + output_size: The total size of the output (unused). + params_dtype: + The data type of the parameters (expected to be torch.float16). + + Returns: + A dictionary containing the quantized weights ('qweight'), + scales ('scales'), and zeros ('zeros'). + + Raises: + ValueError: If `params_dtype` is not `torch.float16` or if the + input size per partition is not divisible by the group size in + `quant_config`. + """ + del input_size, output_size # Unused arguments. 
+ if params_dtype != torch.float16: + raise ValueError("Parameter data type must be torch.float16, " + f"but got {params_dtype}") + group_size = self.quant_config.group_size + if group_size is None: + group_size = -1 + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + if (group_size != -1 and input_size_per_partition % group_size != 0): raise ValueError( f"Input size per partition ({input_size_per_partition}) must " - f"be divisible by group size ({self.quant_config.group_size})." - ) + f"be divisible by group size ({group_size}).") # Initialize or retrieve the BitBLAS matrix multiplication operator. self._configure_bitblas_matmul( @@ -232,19 +362,18 @@ def create_weights( "packed_dim": 1, "bitblas_tile_size": - (self.bitblas_matmul.retrieve_weight_shape()[-2] if - self.bitblas_matmul.transform_weight is not None else None), + (self.bitblas_matmul.retrieve_weight_shape()[-2] + if self.bitblas_matmul.propagate_b else None), "pack_factor": self.quant_config.pack_factor, "weight_propagation": - self.bitblas_matmul.transform_weight is not None, + self.bitblas_matmul.propagate_b, }, ) # Compute the number of input groups for channel-wise quantization. - input_groups = (1 if self.quant_config.group_size == -1 else - input_size_per_partition // - self.quant_config.group_size) + input_groups = (1 if group_size == -1 else input_size_per_partition // + group_size) # Initialize scales and zeros for the quantized weights. scales = Parameter( @@ -302,6 +431,31 @@ def create_weights( layer.register_parameter("zeros", zeros) set_weight_attrs(zeros, extra_weight_attrs) + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + if self.quant_config.quant_method == "bitnet": + return self.create_weights_bitnet(layer, input_size_per_partition, + output_partition_sizes, + input_size, output_size, + params_dtype, + **extra_weight_attrs) + elif self.quant_config.quant_method == "gptq": + return self.create_weights_gptq(layer, input_size_per_partition, + output_partition_sizes, input_size, + output_size, params_dtype, + **extra_weight_attrs) + else: + raise ValueError( + f"Unsupported quant_method {self.quant_config.quant_method}") + def _configure_bitblas_matmul( self, infeatures, @@ -311,28 +465,43 @@ def _configure_bitblas_matmul( bias, layout, bits, + out_dtype="float16", ): from bitblas import MatmulConfig bitblas_dtype = self.BITBLAS_DTYPES[params_dtype] - W_dtype = f"uint{bits}" + if self.quant_config.quant_method == "gptq": + with_scaling = True + with_zeros = True + W_dtype = f"uint{bits}" + group_size = self.quant_config.group_size + zeros_mode = self.quant_config.zeros_mode + elif self.quant_config.quant_method == "bitnet": + with_scaling = False + with_zeros = False + W_dtype = f"int{bits}" + group_size = None + zeros_mode = None + else: + raise ValueError( + f"Unsupported quant_method {self.quant_config.quant_method}") matmul_config = MatmulConfig( - M=self.OPT_FEATURES, + # M=self.OPT_FEATURES, N=outfeatures, K=infeatures, A_dtype=bitblas_dtype, W_dtype=W_dtype, - out_dtype=bitblas_dtype, + out_dtype=out_dtype, accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype, storage_dtype=self.quant_config.STORAGE_DTYPE, - with_scaling=True, - with_zeros=True, - group_size=self.quant_config.group_size, + with_scaling=with_scaling, + with_zeros=with_zeros, + group_size=group_size, 
with_bias=bias, layout=layout, - zeros_mode=self.quant_config.zeros_mode, + zeros_mode=zeros_mode, ) self.bitblas_matmul = self._get_or_create_bitblas_operator( matmul_config, enable_tuning) @@ -364,7 +533,15 @@ def _get_or_create_bitblas_operator(self, config, enable_tuning): logger.info(_message) return bitblas_matmul - def apply( + def activation_quant(self, x, num_bits=8): + x = x.float() + Qn = self.Qn + Qp = self.Qp + s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) + result = (x * s).round().clamp(Qn, Qp) + return result.type(torch.int8) + + def apply_gptq( self, layer: torch.nn.Module, x: torch.Tensor, @@ -384,3 +561,43 @@ def apply( output.add_(bias) # In-place add return output + + def apply_bitnet( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + quant_input = self.activation_quant( + x, self.quant_config.input_bits).detach() + + fp32_out = self.bitblas_matmul(quant_input, layer.qweight) + sw = layer.sw + Qp = self.Qp + si = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) + # if / (si * sw) will cause inf in some cases + output = fp32_out / si + output = output / sw + output = output.half() + output = output.type(x.dtype) + + output = output.view(x.shape[:-1] + (output.shape[1], )) + + if bias is not None: + output.add_(bias) # In-place add + + return output + + def apply( + self, + *args: Any, + **kwargs: Any, + ) -> torch.Tensor: + if self.quant_config.quant_method == "bitnet": + return self.apply_bitnet(*args, **kwargs) + elif self.quant_config.quant_method == "gptq": + return self.apply_gptq(*args, **kwargs) + else: + raise ValueError( + f"Unsupported quant_method {self.quant_config.quant_method}") diff --git a/vllm/model_executor/layers/quantization/bitnet_bitblas.py b/vllm/model_executor/layers/quantization/bitnet_bitblas.py index 0423fcefc904..236dc4439381 100644 --- a/vllm/model_executor/layers/quantization/bitnet_bitblas.py +++ b/vllm/model_executor/layers/quantization/bitnet_bitblas.py @@ -156,6 +156,7 @@ class BITNETBitBLASLinearMethod(LinearMethodBase): def __init__(self, quant_config: BITNETBitBLASConfig) -> None: self.quant_config = quant_config self.Qp = 2**(quant_config.input_bits - 1) - 1 + self.Qn = -2**(quant_config.input_bits - 1) def create_weights( self, @@ -318,8 +319,8 @@ def weight_quant(self, weight): def activation_quant(self, x, num_bits=8): x = x.float() - Qn = -(2**(num_bits - 1)) - Qp = 2**(num_bits - 1) - 1 + Qn = self.Qn + Qp = self.Qp s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) result = (x * s).round().clamp(Qn, Qp) return result.type(torch.int8) @@ -380,7 +381,7 @@ def replace_tensor(name, new_t): setattr(layer, name, new_t) # Repack weights - bitblas_qweight = self.repack_bitblas_from_bitnet(layer.weights) + bitblas_qweight = self.repack_bitblas_from_bitnet(layer.weight) # free the original weight tensor free_tensor("weight") replace_tensor("qweight", bitblas_qweight) @@ -394,6 +395,7 @@ def replace_tensor(name, new_t): output = output / sw output = output.half() output = output.type(x.dtype) + if bias is not None: output.add_(bias) # In-place add diff --git a/vllm/model_executor/models/bitnet.py b/vllm/model_executor/models/bitnet.py index 2c109b47a0ea..e0ede2ff3cc2 100644 --- a/vllm/model_executor/models/bitnet.py +++ b/vllm/model_executor/models/bitnet.py @@ -716,6 +716,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue param = params_dict[name] weight_loader = param.weight_loader + # align 
scaling attr with param + if sum(param.data.shape) == 0 or sum(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.view(param.data.shape) weight_loader(param, loaded_weight, shard_id) break else: @@ -738,6 +741,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) + # align scaling attr with param + if sum(param.data.shape) == 0 or sum(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.view(param.data.shape) weight_loader(param, loaded_weight) # If this function is called, it should always initialize KV cache scale From 7f69aef7f2870149c71de75515b00f2928b9ddc2 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 19 Jul 2024 04:15:17 +0000 Subject: [PATCH 08/24] Improve performance for bitnet --- .../layers/quantization/bitblas.py | 28 +-- .../layers/quantization/bitnet_bitblas.py | 35 +++- vllm/model_executor/models/bitnet.py | 165 ++++-------------- 3 files changed, 84 insertions(+), 144 deletions(-) diff --git a/vllm/model_executor/layers/quantization/bitblas.py b/vllm/model_executor/layers/quantization/bitblas.py index 8af5f82f891d..487fa594b435 100644 --- a/vllm/model_executor/layers/quantization/bitblas.py +++ b/vllm/model_executor/layers/quantization/bitblas.py @@ -533,13 +533,21 @@ def _get_or_create_bitblas_operator(self, config, enable_tuning): logger.info(_message) return bitblas_matmul - def activation_quant(self, x, num_bits=8): + @torch.compile + def activation_quant(self, x): x = x.float() Qn = self.Qn Qp = self.Qp s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) - result = (x * s).round().clamp(Qn, Qp) - return result.type(torch.int8) + result = (x * s).round().clamp(Qn, Qp).type(torch.int8) + return result, s + + @torch.compile + def post_quant_process(self, input, si, sw): + out = input / si + out = out / sw + out = out.half() + return out def apply_gptq( self, @@ -569,18 +577,14 @@ def apply_bitnet( bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - quant_input = self.activation_quant( - x, self.quant_config.input_bits).detach() + quant_input, si = self.activation_quant(x) + + output = self.bitblas_matmul(quant_input, layer.qweight) - fp32_out = self.bitblas_matmul(quant_input, layer.qweight) sw = layer.sw - Qp = self.Qp - si = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) + # if / (si * sw) will cause inf in some cases - output = fp32_out / si - output = output / sw - output = output.half() - output = output.type(x.dtype) + output = self.post_quant_process(output, si, sw) output = output.view(x.shape[:-1] + (output.shape[1], )) diff --git a/vllm/model_executor/layers/quantization/bitnet_bitblas.py b/vllm/model_executor/layers/quantization/bitnet_bitblas.py index 236dc4439381..93cc6dc0fe21 100644 --- a/vllm/model_executor/layers/quantization/bitnet_bitblas.py +++ b/vllm/model_executor/layers/quantization/bitnet_bitblas.py @@ -8,6 +8,8 @@ from vllm.logger import init_logger from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 QuantizationConfig) @@ -327,7 +329,8 @@ def activation_quant(self, x, num_bits=8): def repack_bitblas_from_bitnet(self, b_q_weight: torch.Tensor, - is_qkv_packed: bool = False): + is_qkv_packed: bool = False, + is_gateup_packed: bool = False): if is_qkv_packed: hidden_size = b_q_weight.size(0) sw_q = 1 / 
b_q_weight[:hidden_size // @@ -336,9 +339,11 @@ def repack_bitblas_from_bitnet(self, 3].abs().mean().clamp(min=1e-5) sw_v = 1 / b_q_weight[2 * hidden_size // 3:].abs().mean().clamp(min=1e-5) - self.sw_q = sw_q - self.sw_k = sw_k - self.sw_v = sw_v + self.sw = torch.cat( + (sw_q.repeat(hidden_size // 3), sw_k.repeat( + hidden_size // 3), sw_v.repeat(hidden_size // 3)), + dim=0) + qweight_q = self.weight_quant(b_q_weight[:hidden_size // 3]).detach() qweight_k = self.weight_quant( @@ -346,6 +351,20 @@ def repack_bitblas_from_bitnet(self, qweight_v = self.weight_quant(b_q_weight[2 * hidden_size // 3:]).detach() qweight = torch.cat([qweight_q, qweight_k, qweight_v], dim=0) + elif is_gateup_packed: + hidden_size = b_q_weight.size(0) + sw_gate = 1 / b_q_weight[:hidden_size // + 2].abs().mean().clamp(min=1e-5) + sw_up = 1 / b_q_weight[hidden_size // + 2:].abs().mean().clamp(min=1e-5) + self.sw = torch.cat((sw_gate.repeat( + hidden_size // 2), sw_up.repeat(hidden_size // 2)), + dim=0) + qweight_gate = self.weight_quant(b_q_weight[:hidden_size // + 2]).detach() + qweight_up = self.weight_quant(b_q_weight[hidden_size // + 2:]).detach() + qweight = torch.cat([qweight_gate, qweight_up], dim=0) else: sw = 1 / b_q_weight.abs().mean().clamp(min=1e-5) self.sw = sw @@ -381,7 +400,13 @@ def replace_tensor(name, new_t): setattr(layer, name, new_t) # Repack weights - bitblas_qweight = self.repack_bitblas_from_bitnet(layer.weight) + # QKVParallelLinear is a special case where the weight is packed + # For bitnet as different weights matrix shouldn't share the same + # scale, we need to unpack and repack the weight matrix + is_qkv_packed = isinstance(layer, QKVParallelLinear) + is_gateup_packed = isinstance(layer, MergedColumnParallelLinear) + bitblas_qweight = self.repack_bitblas_from_bitnet( + layer.weight, is_qkv_packed, is_gateup_packed) # free the original weight tensor free_tensor("weight") replace_tensor("qweight", bitblas_qweight) diff --git a/vllm/model_executor/models/bitnet.py b/vllm/model_executor/models/bitnet.py index e0ede2ff3cc2..8bf0d23bd375 100644 --- a/vllm/model_executor/models/bitnet.py +++ b/vllm/model_executor/models/bitnet.py @@ -36,7 +36,9 @@ get_tensor_model_parallel_world_size) from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( @@ -255,8 +257,8 @@ def __init__( raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") self.act_fn = SiluAndMul() - self.ffn_layernorm = BitnetRMSNorm(intermediate_size, - eps=config.rms_norm_eps) + self.ffn_layernorm = RMSNorm(intermediate_size, + eps=config.rms_norm_eps) def forward(self, x): gate_up, _ = self.gate_up_proj(x) @@ -266,76 +268,6 @@ def forward(self, x): return x -class BitnetRotaryEmbedding(nn.Module): - - def __init__( - self, - dim, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0, - ): - super().__init__() - self.scaling_factor = scaling_factor - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base**(torch.arange( - 0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq) - # For BC we register cos and sin cached - self.max_seq_len_cached = max_position_embeddings - t = torch.arange(self.max_seq_len_cached, - device=device, - dtype=torch.int64).type_as(self.inv_freq) - t = t / self.scaling_factor - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("_cos_cached", - emb.cos().to(torch.get_default_dtype()), - persistent=False) - self.register_buffer("_sin_cached", - emb.sin().to(torch.get_default_dtype()), - persistent=False) - - @property - def sin_cached(self): - logger.warning_once( - "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " - "the forward method of RoPE from now on instead. It is not used in the `BitnetAttention` class" - ) - return self._sin_cached - - @property - def cos_cached(self): - logger.warning_once( - "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " - "the forward method of RoPE from now on instead. 
It is not used in the `BitnetAttention` class" - ) - return self._cos_cached - - @torch.no_grad() - def forward(self, x, position_ids): - # x: [bs, num_attention_heads, seq_len, head_size] - inv_freq_expanded = (self.inv_freq[None, :, None].float().expand( - position_ids.shape[0], -1, 1)) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - device_type = x.device.type - device_type = (device_type if isinstance(device_type, str) - and device_type != "mps" else "cpu") - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() - @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - class BitnetAttention(nn.Module): def __init__( @@ -378,23 +310,11 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.attention_dropout = config.attention_dropout - self.q_proj = RowParallelLinear( - input_size=hidden_size, - output_size=self.head_dim * self.num_heads, - bias=bias, - quant_config=quant_config, - ) - - self.k_proj = RowParallelLinear( - input_size=hidden_size, - output_size=self.head_dim * self.num_heads, - bias=bias, - quant_config=quant_config, - ) - - self.v_proj = RowParallelLinear( - input_size=hidden_size, - output_size=self.head_dim * self.num_heads, + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, bias=bias, quant_config=quant_config, ) @@ -423,8 +343,8 @@ def __init__( rope_scaling=rope_scaling, ) - self.inner_attn_ln = BitnetRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.inner_attn_ln = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) def find_flash_attn_supported_head_dims(self, head_dim: int) -> int: """ @@ -450,9 +370,8 @@ def forward( ) -> torch.Tensor: # QKV projection cannot be grouped as the they # do not share the same scaling factor - q, _ = self.q_proj(hidden_states) - k, _ = self.k_proj(hidden_states) - v, _ = self.v_proj(hidden_states) + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) # Padding as paged attention doesn't support head_dim == 100 q = torch.nn.functional.pad( @@ -477,25 +396,6 @@ def forward( return output -class BitnetRMSNorm(nn.Module): - - def __init__(self, hidden_size, eps=1e-6): - """ - BitnetRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + - self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - class BitnetDecoderLayer(nn.Module): def __init__( @@ -538,10 +438,10 @@ def __init__( bias=getattr(config, "mlp_bias", False), config=config, ) - self.input_layernorm = BitnetRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = BitnetRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = 
RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) def forward( self, @@ -549,23 +449,29 @@ def forward( hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, attn_metadata=attn_metadata, ) - hidden_states = residual + hidden_states + # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - return hidden_states + + return hidden_states, residual class BitnetModel(nn.Module): @@ -592,7 +498,7 @@ def __init__( quant_config=quant_config) for _ in range(config.num_hidden_layers) ]) - self.norm = BitnetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -609,15 +515,17 @@ def forward( hidden_states = inputs_embeds else: hidden_states = self.get_input_embeddings(input_ids) + residual = None for i in range(len(self.layers)): layer = self.layers[i] - hidden_states = layer( + hidden_states, residual = layer( positions, hidden_states, kv_caches[i], attn_metadata, + residual, ) - hidden_states = self.norm(hidden_states) + hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -695,6 +603,9 @@ def sample( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), (".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1), ] From a973123209a1e5c44bcda4455df9090c2026c5fc Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 19 Jul 2024 06:00:29 +0000 Subject: [PATCH 09/24] fix lm_head for gptq model refactor --- vllm/model_executor/models/bitnet.py | 81 +++++++++++++++++++--------- 1 file changed, 56 insertions(+), 25 deletions(-) diff --git a/vllm/model_executor/models/bitnet.py b/vllm/model_executor/models/bitnet.py index 8bf0d23bd375..52be3542df31 100644 --- a/vllm/model_executor/models/bitnet.py +++ b/vllm/model_executor/models/bitnet.py @@ -32,7 +32,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import (get_tensor_model_parallel_rank, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul @@ -50,8 +50,9 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput +from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.utils import is_hip, print_warning_once +from .utils import PPMissingLayer logger = 
init_logger(__name__) @@ -509,12 +510,19 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if inputs_embeds is not None: - hidden_states = inputs_embeds + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None else: - hidden_states = self.get_input_embeddings(input_ids) + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] residual = None for i in range(len(self.layers)): layer = self.layers[i] @@ -559,21 +567,25 @@ def __init__( self.config = config self.model = BitnetModel(config, cache_config, quant_config) - self.unpadded_vocab_size = config.vocab_size - self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, - ) - if config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size - logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, logit_scale) - self.sampler = Sampler() + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, logit_scale) + self.sampler = Sampler() + else: + self.lm_head = PPMissingLayer() def forward( self, @@ -581,14 +593,16 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + hidden_states = self.model( + input_ids, positions, kv_caches, attn_metadata, intermediate_tensors + ) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head.weight, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits @@ -600,6 +614,20 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -639,19 +667,22 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Remapping the name of FP8 kv-scale. 
if name.endswith("kv_scale"): remapped_kv_scale_name = name.replace( - ".kv_scale", ".attn.kv_scale") + ".kv_scale", ".attn.kv_scale" + ) if remapped_kv_scale_name not in params_dict: print_warning_once( f"Found kv scale in the checkpoint (e.g. {name}), " "but not found the expected name in the model " f"(e.g. {remapped_kv_scale_name}). kv-scale is " - "not loaded.") + "not loaded." + ) continue else: name = remapped_kv_scale_name param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) # align scaling attr with param if sum(param.data.shape) == 0 or sum(loaded_weight.shape) == 0: loaded_weight = loaded_weight.view(param.data.shape) From aea1f4c651f5d9dafd635030ae8c82923cbe6487 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 19 Jul 2024 06:32:14 +0000 Subject: [PATCH 10/24] linx fix --- vllm/model_executor/models/bitnet.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/bitnet.py b/vllm/model_executor/models/bitnet.py index 52be3542df31..cbf92fda4347 100644 --- a/vllm/model_executor/models/bitnet.py +++ b/vllm/model_executor/models/bitnet.py @@ -52,6 +52,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.utils import is_hip, print_warning_once + from .utils import PPMissingLayer logger = init_logger(__name__) @@ -582,7 +583,8 @@ def __init__( logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, logit_scale) + config.vocab_size, + logit_scale) self.sampler = Sampler() else: self.lm_head = PPMissingLayer() @@ -595,9 +597,8 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: - hidden_states = self.model( - input_ids, positions, kv_caches, attn_metadata, intermediate_tensors - ) + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, @@ -667,22 +668,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Remapping the name of FP8 kv-scale. if name.endswith("kv_scale"): remapped_kv_scale_name = name.replace( - ".kv_scale", ".attn.kv_scale" - ) + ".kv_scale", ".attn.kv_scale") if remapped_kv_scale_name not in params_dict: print_warning_once( f"Found kv scale in the checkpoint (e.g. {name}), " "but not found the expected name in the model " f"(e.g. {remapped_kv_scale_name}). kv-scale is " - "not loaded." - ) + "not loaded.") continue else: name = remapped_kv_scale_name param = params_dict[name] - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) # align scaling attr with param if sum(param.data.shape) == 0 or sum(loaded_weight.shape) == 0: loaded_weight = loaded_weight.view(param.data.shape) From 17128d5c9043f27f7b5e452cbb2bb784a66f4f77 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 13 Aug 2024 12:15:01 +0000 Subject: [PATCH 11/24] handle compressed scale weight. 
--- .../layers/quantization/bitblas.py | 3 +-- .../model_loader/weight_utils.py | 8 +++++++ vllm/model_executor/models/bitnet.py | 21 ++++++++++++++++++- 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/bitblas.py b/vllm/model_executor/layers/quantization/bitblas.py index 487fa594b435..ec86ba96f42a 100644 --- a/vllm/model_executor/layers/quantization/bitblas.py +++ b/vllm/model_executor/layers/quantization/bitblas.py @@ -264,7 +264,7 @@ def create_weights_bitnet( sw = Parameter( torch.empty( - 1, + (output_size_per_partition,), device="cuda", dtype=params_dtype, ), @@ -488,7 +488,6 @@ def _configure_bitblas_matmul( f"Unsupported quant_method {self.quant_config.quant_method}") matmul_config = MatmulConfig( - # M=self.OPT_FEATURES, N=outfeatures, K=infeatures, A_dtype=bitblas_dtype, diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 698c59d49fe0..adc3f9eaba50 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -435,6 +435,14 @@ def default_weight_loader(param: torch.Tensor, assert param.size() == loaded_weight.size() param.data.copy_(loaded_weight) +def partital_weight_loader(param: torch.Tensor, + loaded_weight: torch.Tensor, id: int) -> None: + """Partition weight loader.""" + param_size = param.size().numel() + loaded_weight_size = loaded_weight.size().numel() + assert param_size % loaded_weight_size == 0 + assert id < (param_size // loaded_weight_size) + param.data[id*loaded_weight_size:(id+1)*loaded_weight_size].copy_(loaded_weight) def initialize_dummy_weights( model: torch.nn.Module, diff --git a/vllm/model_executor/models/bitnet.py b/vllm/model_executor/models/bitnet.py index cbf92fda4347..015ce4516d41 100644 --- a/vllm/model_executor/models/bitnet.py +++ b/vllm/model_executor/models/bitnet.py @@ -48,7 +48,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, kv_cache_scales_loader) + default_weight_loader, kv_cache_scales_loader, partital_weight_loader) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.utils import is_hip, print_warning_once @@ -654,6 +654,20 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # This is a trick to load weight scale for BitNet Models + if name.endswith(".sw") and name in params_dict: + param = params_dict[name] + if param_name == ".qkv_proj": + loaded_weight = loaded_weight.repeat((param.data.shape[0] // 3, *param.data.shape[1:])) + shard_id = { + "q": 0, + "k": 1, + "v": 2 + }[shard_id] + elif param_name == ".gate_up_proj": + loaded_weight = loaded_weight.repeat((param.data.shape[0] // 2, *param.data.shape[1:])) + partital_weight_loader(param, loaded_weight, shard_id) + break param = params_dict[name] weight_loader = param.weight_loader # align scaling attr with param @@ -665,6 +679,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue + # This is a trick to load weight scale for BitNet Models + if name.endswith(".sw") and name in params_dict: + param = params_dict[name] + weight_loader(param, loaded_weight.repeat(param.data.shape)) + continue # Remapping the name of FP8 kv-scale. if name.endswith("kv_scale"): remapped_kv_scale_name = name.replace( From 1741ed49c5a13bbad42c4df6b60e605ad48a9f07 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 13 Aug 2024 12:15:08 +0000 Subject: [PATCH 12/24] lint fix --- .../layers/quantization/bitblas.py | 2 +- .../model_executor/model_loader/weight_utils.py | 9 ++++++--- vllm/model_executor/models/bitnet.py | 17 ++++++++--------- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/layers/quantization/bitblas.py b/vllm/model_executor/layers/quantization/bitblas.py index ec86ba96f42a..debede2afafc 100644 --- a/vllm/model_executor/layers/quantization/bitblas.py +++ b/vllm/model_executor/layers/quantization/bitblas.py @@ -264,7 +264,7 @@ def create_weights_bitnet( sw = Parameter( torch.empty( - (output_size_per_partition,), + (output_size_per_partition, ), device="cuda", dtype=params_dtype, ), diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index adc3f9eaba50..ed560b0c4b5b 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -435,14 +435,17 @@ def default_weight_loader(param: torch.Tensor, assert param.size() == loaded_weight.size() param.data.copy_(loaded_weight) -def partital_weight_loader(param: torch.Tensor, - loaded_weight: torch.Tensor, id: int) -> None: + +def partital_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor, + id: int) -> None: """Partition weight loader.""" param_size = param.size().numel() loaded_weight_size = loaded_weight.size().numel() assert param_size % loaded_weight_size == 0 assert id < (param_size // loaded_weight_size) - param.data[id*loaded_weight_size:(id+1)*loaded_weight_size].copy_(loaded_weight) + param.data[id * loaded_weight_size:(id + 1) * + loaded_weight_size].copy_(loaded_weight) + def initialize_dummy_weights( model: torch.nn.Module, diff --git a/vllm/model_executor/models/bitnet.py b/vllm/model_executor/models/bitnet.py index 015ce4516d41..0a070413ff5f 100644 --- a/vllm/model_executor/models/bitnet.py +++ b/vllm/model_executor/models/bitnet.py @@ -657,15 +657,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # This is a trick to load weight scale for BitNet Models if name.endswith(".sw") and name in params_dict: param = params_dict[name] - if param_name == ".qkv_proj": - loaded_weight = loaded_weight.repeat((param.data.shape[0] // 3, *param.data.shape[1:])) - shard_id = { - "q": 0, - "k": 1, - "v": 2 - }[shard_id] + if param_name == ".qkv_proj": + loaded_weight = loaded_weight.repeat( + (param.data.shape[0] // 3, *param.data.shape[1:])) + shard_id = {"q": 0, "k": 1, "v": 2}[shard_id] elif param_name == ".gate_up_proj": - loaded_weight = loaded_weight.repeat((param.data.shape[0] // 2, *param.data.shape[1:])) + loaded_weight = loaded_weight.repeat( + (param.data.shape[0] // 2, *param.data.shape[1:])) partital_weight_loader(param, loaded_weight, shard_id) break param = params_dict[name] @@ -682,7 +680,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # This is a trick to load weight scale for BitNet Models if name.endswith(".sw") and name in params_dict: param = 
params_dict[name] - weight_loader(param, loaded_weight.repeat(param.data.shape)) + weight_loader(param, + loaded_weight.repeat(param.data.shape)) continue # Remapping the name of FP8 kv-scale. if name.endswith("kv_scale"): From 726a1f791c2c6b2dd6897f9cbfab994b654c8938 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 15 Aug 2024 05:27:31 +0000 Subject: [PATCH 13/24] remove partial weight load for sw --- .../model_loader/weight_utils.py | 11 ---------- vllm/model_executor/models/bitnet.py | 21 +------------------ 2 files changed, 1 insertion(+), 31 deletions(-) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index ed560b0c4b5b..698c59d49fe0 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -436,17 +436,6 @@ def default_weight_loader(param: torch.Tensor, param.data.copy_(loaded_weight) -def partital_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor, - id: int) -> None: - """Partition weight loader.""" - param_size = param.size().numel() - loaded_weight_size = loaded_weight.size().numel() - assert param_size % loaded_weight_size == 0 - assert id < (param_size // loaded_weight_size) - param.data[id * loaded_weight_size:(id + 1) * - loaded_weight_size].copy_(loaded_weight) - - def initialize_dummy_weights( model: torch.nn.Module, low: float = -1e-3, diff --git a/vllm/model_executor/models/bitnet.py b/vllm/model_executor/models/bitnet.py index 0a070413ff5f..a2dd6fb3aa97 100644 --- a/vllm/model_executor/models/bitnet.py +++ b/vllm/model_executor/models/bitnet.py @@ -48,7 +48,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, kv_cache_scales_loader, partital_weight_loader) + default_weight_loader, kv_cache_scales_loader) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.utils import is_hip, print_warning_once @@ -654,18 +654,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - # This is a trick to load weight scale for BitNet Models - if name.endswith(".sw") and name in params_dict: - param = params_dict[name] - if param_name == ".qkv_proj": - loaded_weight = loaded_weight.repeat( - (param.data.shape[0] // 3, *param.data.shape[1:])) - shard_id = {"q": 0, "k": 1, "v": 2}[shard_id] - elif param_name == ".gate_up_proj": - loaded_weight = loaded_weight.repeat( - (param.data.shape[0] // 2, *param.data.shape[1:])) - partital_weight_loader(param, loaded_weight, shard_id) - break param = params_dict[name] weight_loader = param.weight_loader # align scaling attr with param @@ -677,13 +665,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - # This is a trick to load weight scale for BitNet Models - if name.endswith(".sw") and name in params_dict: - param = params_dict[name] - weight_loader(param, - loaded_weight.repeat(param.data.shape)) - continue - # Remapping the name of FP8 kv-scale. 
if name.endswith("kv_scale"): remapped_kv_scale_name = name.replace( ".kv_scale", ".attn.kv_scale") From 68c8052434cb7e0d04c04e754b6b441a37721396 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 15 Aug 2024 05:32:47 +0000 Subject: [PATCH 14/24] apply torch compile for uncompressed weight. --- .../layers/quantization/bitblas.py | 1 - .../layers/quantization/bitnet_bitblas.py | 28 +++++++++---------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/layers/quantization/bitblas.py b/vllm/model_executor/layers/quantization/bitblas.py index debede2afafc..8b6ac41347c3 100644 --- a/vllm/model_executor/layers/quantization/bitblas.py +++ b/vllm/model_executor/layers/quantization/bitblas.py @@ -582,7 +582,6 @@ def apply_bitnet( sw = layer.sw - # if / (si * sw) will cause inf in some cases output = self.post_quant_process(output, si, sw) output = output.view(x.shape[:-1] + (output.shape[1], )) diff --git a/vllm/model_executor/layers/quantization/bitnet_bitblas.py b/vllm/model_executor/layers/quantization/bitnet_bitblas.py index 93cc6dc0fe21..0531670d19d1 100644 --- a/vllm/model_executor/layers/quantization/bitnet_bitblas.py +++ b/vllm/model_executor/layers/quantization/bitnet_bitblas.py @@ -319,13 +319,21 @@ def weight_quant(self, weight): result = (weight * s).round().clamp(-1, 1) return result.type(torch.int8) - def activation_quant(self, x, num_bits=8): + @torch.compile + def activation_quant(self, x): x = x.float() Qn = self.Qn Qp = self.Qp s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) - result = (x * s).round().clamp(Qn, Qp) - return result.type(torch.int8) + result = (x * s).round().clamp(Qn, Qp).type(torch.int8) + return result, s + + @torch.compile + def post_quant_process(self, input, si, sw): + out = input / si + out = out / sw + out = out.half() + return out def repack_bitblas_from_bitnet(self, b_q_weight: torch.Tensor, @@ -381,8 +389,7 @@ def apply( part_size_n = layer.output_size_per_partition out_shape = x.shape[:-1] + (part_size_n, ) - quant_input = self.activation_quant( - x, self.quant_config.input_bits).detach() + quant_input, si = self.activation_quant(x) if layer.bitblas_state == BITNETBitBLASState.REPACK: layer.bitblas_state = BITNETBitBLASState.READY @@ -411,15 +418,8 @@ def replace_tensor(name, new_t): free_tensor("weight") replace_tensor("qweight", bitblas_qweight) - fp32_out = self.bitblas_matmul(quant_input, layer.qweight) - sw = self.sw - Qp = self.Qp - si = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) - # if / (si * sw) it will inf in some cases - output = fp32_out / si - output = output / sw - output = output.half() - output = output.type(x.dtype) + output = self.bitblas_matmul(quant_input, layer.qweight) + output = self.post_quant_process(output, si, self.sw) if bias is not None: output.add_(bias) # In-place add From 52418ef13bb0ac84769f3a6dee28bd0ecd394c55 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 15 Aug 2024 07:06:37 +0000 Subject: [PATCH 15/24] merge bug fix --- vllm/model_executor/layers/quantization/bitblas.py | 3 ++- vllm/model_executor/layers/quantization/bitnet_bitblas.py | 3 ++- vllm/model_executor/layers/quantization/gptq_bitblas.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/bitblas.py b/vllm/model_executor/layers/quantization/bitblas.py index 8b6ac41347c3..d07c3c734c74 100644 --- a/vllm/model_executor/layers/quantization/bitblas.py +++ b/vllm/model_executor/layers/quantization/bitblas.py @@ -145,7 +145,8 @@ 
def override_quantization_method(cls, hf_quant_cfg, return None def get_quant_method( - self, layer: torch.nn.Module) -> Optional["BitBLASLinearMethod"]: + self, layer: torch.nn.Module, + prefix: str) -> Optional["BitBLASLinearMethod"]: if isinstance(layer, LinearBase): return BitBLASLinearMethod(self) return None diff --git a/vllm/model_executor/layers/quantization/bitnet_bitblas.py b/vllm/model_executor/layers/quantization/bitnet_bitblas.py index 0531670d19d1..492f24206c17 100644 --- a/vllm/model_executor/layers/quantization/bitnet_bitblas.py +++ b/vllm/model_executor/layers/quantization/bitnet_bitblas.py @@ -107,7 +107,8 @@ def override_quantization_method(cls, hf_quant_cfg, def get_quant_method( self, - layer: torch.nn.Module) -> Optional["BITNETBitBLASLinearMethod"]: + layer: torch.nn.Module, + prefix: str) -> Optional["BITNETBitBLASLinearMethod"]: if isinstance(layer, LinearBase): return BITNETBitBLASLinearMethod(self) return None diff --git a/vllm/model_executor/layers/quantization/gptq_bitblas.py b/vllm/model_executor/layers/quantization/gptq_bitblas.py index a54f839a7e52..63f2f1a2ad36 100644 --- a/vllm/model_executor/layers/quantization/gptq_bitblas.py +++ b/vllm/model_executor/layers/quantization/gptq_bitblas.py @@ -150,7 +150,8 @@ def override_quantization_method(cls, hf_quant_cfg, def get_quant_method( self, - layer: torch.nn.Module) -> Optional["GPTQBitBLASLinearMethod"]: + layer: torch.nn.Module, + prefix: str) -> Optional["GPTQBitBLASLinearMethod"]: if isinstance(layer, LinearBase): return GPTQBitBLASLinearMethod(self) return None From a15ba12f1c15b4019fa62ed00c3f73019f61e775 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 15 Aug 2024 07:27:27 +0000 Subject: [PATCH 16/24] lint fix --- vllm/config.py | 4 ++-- vllm/model_executor/layers/quantization/bitblas.py | 5 ++--- vllm/model_executor/layers/quantization/bitnet_bitblas.py | 6 ++---- vllm/model_executor/layers/quantization/gptq_bitblas.py | 6 ++---- 4 files changed, 8 insertions(+), 13 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 0f19eed80976..7d80a3509721 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -249,8 +249,8 @@ def _verify_quantization(self) -> None: supported_quantization = [*QUANTIZATION_METHODS] rocm_supported_quantization = ["gptq", "squeezellm"] optimized_quantization_methods = [ - "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", - "gptq_bitblas", "bitblas", "bitnet_bitblas", "compressed_tensors" + "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "gptq_bitblas", + "bitblas", "bitnet_bitblas", "compressed_tensors" ] tpu_supported_quantization = ["tpu_int8"] if self.quantization is not None: diff --git a/vllm/model_executor/layers/quantization/bitblas.py b/vllm/model_executor/layers/quantization/bitblas.py index d07c3c734c74..61c9013b0efe 100644 --- a/vllm/model_executor/layers/quantization/bitblas.py +++ b/vllm/model_executor/layers/quantization/bitblas.py @@ -144,9 +144,8 @@ def override_quantization_method(cls, hf_quant_cfg, return None - def get_quant_method( - self, layer: torch.nn.Module, - prefix: str) -> Optional["BitBLASLinearMethod"]: + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["BitBLASLinearMethod"]: if isinstance(layer, LinearBase): return BitBLASLinearMethod(self) return None diff --git a/vllm/model_executor/layers/quantization/bitnet_bitblas.py b/vllm/model_executor/layers/quantization/bitnet_bitblas.py index 492f24206c17..04625a536df8 100644 --- a/vllm/model_executor/layers/quantization/bitnet_bitblas.py +++ 
b/vllm/model_executor/layers/quantization/bitnet_bitblas.py @@ -105,10 +105,8 @@ def override_quantization_method(cls, hf_quant_cfg, " faster inference") return None - def get_quant_method( - self, - layer: torch.nn.Module, - prefix: str) -> Optional["BITNETBitBLASLinearMethod"]: + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["BITNETBitBLASLinearMethod"]: if isinstance(layer, LinearBase): return BITNETBitBLASLinearMethod(self) return None diff --git a/vllm/model_executor/layers/quantization/gptq_bitblas.py b/vllm/model_executor/layers/quantization/gptq_bitblas.py index 63f2f1a2ad36..3a5c7b45e764 100644 --- a/vllm/model_executor/layers/quantization/gptq_bitblas.py +++ b/vllm/model_executor/layers/quantization/gptq_bitblas.py @@ -148,10 +148,8 @@ def override_quantization_method(cls, hf_quant_cfg, " faster inference") return None - def get_quant_method( - self, - layer: torch.nn.Module, - prefix: str) -> Optional["GPTQBitBLASLinearMethod"]: + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["GPTQBitBLASLinearMethod"]: if isinstance(layer, LinearBase): return GPTQBitBLASLinearMethod(self) return None From 53babaed83d07b6e9f51d662997f8ac26ead615f Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 18 Aug 2024 13:07:54 +0000 Subject: [PATCH 17/24] fix torch compile issue --- vllm/model_executor/layers/quantization/bitblas.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/bitblas.py b/vllm/model_executor/layers/quantization/bitblas.py index 61c9013b0efe..a58053f45598 100644 --- a/vllm/model_executor/layers/quantization/bitblas.py +++ b/vllm/model_executor/layers/quantization/bitblas.py @@ -541,10 +541,10 @@ def activation_quant(self, x): result = (x * s).round().clamp(Qn, Qp).type(torch.int8) return result, s - @torch.compile + @torch.compile(dynamic=True) def post_quant_process(self, input, si, sw): out = input / si - out = out / sw + out.div_(sw) out = out.half() return out From 40a4e53ac6c009d5c6e1b8e08f1f88bac7493ff8 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 20 Aug 2024 13:08:12 +0000 Subject: [PATCH 18/24] bug fix. --- vllm/model_executor/layers/quantization/__init__.py | 2 +- vllm/model_executor/layers/quantization/bitblas.py | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index de087ad7dae6..9a821b7cd712 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -38,12 +38,12 @@ "fbgemm_fp8": FBGEMMFp8Config, # The order of gptq methods is important for config.py iteration over # override_quantization_method(..) 
- "bitnet_bitblas": BITNETBitBLASConfig, "marlin": MarlinConfig, "bitblas": BitBLASConfig, "gptq_marlin_24": GPTQMarlin24Config, "gptq_marlin": GPTQMarlinConfig, "gptq_bitblas": GPTQBitBLASConfig, + "bitnet_bitblas": BITNETBitBLASConfig, "gguf": GGUFConfig, "awq_marlin": AWQMarlinConfig, "gptq": GPTQConfig, diff --git a/vllm/model_executor/layers/quantization/bitblas.py b/vllm/model_executor/layers/quantization/bitblas.py index a58053f45598..bf47ce7a6a73 100644 --- a/vllm/model_executor/layers/quantization/bitblas.py +++ b/vllm/model_executor/layers/quantization/bitblas.py @@ -475,6 +475,9 @@ def _configure_bitblas_matmul( with_scaling = True with_zeros = True W_dtype = f"uint{bits}" + if self.quant_config.is_sym: + with_zeros = False + W_dtype = f"int{bits}" group_size = self.quant_config.group_size zeros_mode = self.quant_config.zeros_mode elif self.quant_config.quant_method == "bitnet": @@ -560,7 +563,10 @@ def apply_gptq( x_2d = x.view(-1, x.shape[-1]) - output_2d = self.bitblas_matmul(x_2d, qweight, scales, qzeros) + if self.quant_config.is_sym: + output_2d = self.bitblas_matmul(x_2d, qweight, scales) + else: + output_2d = self.bitblas_matmul(x_2d, qweight, scales, qzeros) output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) From d316a87241334f310f7ef35d4f05e26724ffb624 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 20 Aug 2024 14:08:32 +0000 Subject: [PATCH 19/24] BENCHMARK SCRIPTS --- benchmarks/kernels/benchmark_bitblas.py | 675 ++++++++++++++++++ .../layers/quantization/bitblas.py | 3 +- 2 files changed, 676 insertions(+), 2 deletions(-) create mode 100644 benchmarks/kernels/benchmark_bitblas.py diff --git a/benchmarks/kernels/benchmark_bitblas.py b/benchmarks/kernels/benchmark_bitblas.py new file mode 100644 index 000000000000..410355a42120 --- /dev/null +++ b/benchmarks/kernels/benchmark_bitblas.py @@ -0,0 +1,675 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from vllm.model_executor.layers.quantization.bitblas import ( + BITBLAS_TARGET, + Matmul, + MatmulConfig, +) +from vllm.utils import FlexibleArgumentParser + +parser = FlexibleArgumentParser( + description="Benchmark BitBLAS int4 on a specific target." +) + +# Add arguments to the parser +parser.add_argument( + "--target", + type=str, + default=BITBLAS_TARGET, + help="Specify the target device for benchmarking.", +) +parser.add_argument( + "--group_size", type=int, default=None, + help="Group size for grouped quantization." +) +parser.add_argument( + "--A_dtype", + type=str, + default="float16", + choices=["float16", "float32", "float64", "int32", "int8"], + help="Data type of activation A.", +) +parser.add_argument( + "--W_dtype", + type=str, + default="int4", + choices=[ + "float16", + "float32", + "float64", + "int32", + "int8", + "int4", + "int2", + "int1", + "nf4", + "fp4_e2m1", + ], + help="Data type of weight W.", +) +parser.add_argument( + "--accum_dtype", + type=str, + default="float16", + choices=["float16", "int32"], + help="Data type for accumulation.", +) +parser.add_argument( + "--out_dtype", + type=str, + default="float16", + choices=["float16", "float32", "int32", "int8"], + help="Data type for output.", +) +parser.add_argument( + "--layout", + type=str, + default="nt", + choices=["nt", "nn"], + help="Matrix layout, 'nt' for non-transpose A and transpose W.", +) +parser.add_argument( + "--with_bias", action="store_true", help="Include bias in the benchmark." 
+) +parser.add_argument( + "--with_scaling", + action="store_true", + help="Include scaling factor in the quantization.", +) +parser.add_argument( + "--with_zeros", action="store_true", + help="Include zeros in the quantization." +) +parser.add_argument( + "--zeros_mode", + type=str, + default=None, + choices=["original", "rescale", "quantized"], + help="Specify the mode for calculating zeros.", +) + +# Parse the arguments +args = parser.parse_args() + +# Assign arguments to variables +target = args.target +group_size = args.group_size +A_dtype = args.A_dtype +W_dtype = args.W_dtype +accum_dtype = args.accum_dtype +out_dtype = args.out_dtype +layout = args.layout +with_bias = args.with_bias +group_size = args.group_size +with_scaling = args.with_scaling +with_zeros = args.with_zeros +zeros_mode = args.zeros_mode + +test_shapes = [ + # square test + ( + MatmulConfig, + Matmul, + ( + 1, + 16384, + 16384, + A_dtype, + W_dtype, + out_dtype, + accum_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, + ), + ), + # BLOOM-176B + ( + MatmulConfig, + Matmul, + ( + 1, + 43008, + 14336, + A_dtype, + W_dtype, + out_dtype, + accum_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, + ), + ), + ( + MatmulConfig, + Matmul, + ( + 1, + 14336, + 14336, + A_dtype, + W_dtype, + out_dtype, + accum_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, + ), + ), + ( + MatmulConfig, + Matmul, + ( + 1, + 57344, + 14336, + A_dtype, + W_dtype, + out_dtype, + accum_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, + ), + ), + ( + MatmulConfig, + Matmul, + ( + 1, + 14336, + 57344, + A_dtype, + W_dtype, + out_dtype, + accum_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, + ), + ), + # # OPT-65B + ( + MatmulConfig, + Matmul, + ( + 1, + 9216, + 9216, + A_dtype, + W_dtype, + out_dtype, + accum_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, + ), + ), + ( + MatmulConfig, + Matmul, + ( + 1, + 36864, + 9216, + A_dtype, + W_dtype, + out_dtype, + accum_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, + ), + ), + ( + MatmulConfig, + Matmul, + ( + 1, + 9216, + 36864, + A_dtype, + W_dtype, + out_dtype, + accum_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, + ), + ), + ( + MatmulConfig, + Matmul, + ( + 1, + 22016, + 8192, + A_dtype, + W_dtype, + out_dtype, + accum_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, + ), + ), + # # LLAMA-70B/65B + ( + MatmulConfig, + Matmul, + ( + 1, + 8192, + 22016, + A_dtype, + W_dtype, + out_dtype, + accum_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, + ), + ), + ( + MatmulConfig, + Matmul, + ( + 1, + 8192, + 8192, + A_dtype, + W_dtype, + out_dtype, + accum_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, + ), + ), + ( + MatmulConfig, + Matmul, + ( + 1, + 28672, + 8192, + A_dtype, + W_dtype, + out_dtype, + accum_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, + ), + ), + ( + MatmulConfig, + Matmul, + ( + 1, + 8192, + 28672, + A_dtype, + W_dtype, + out_dtype, + accum_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, + ), + ), + # square test + ( + MatmulConfig, + Matmul, + ( + 16384, + 16384, + 16384, + 
A_dtype, + W_dtype, + out_dtype, + accum_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, + ), + ), + # BLOOM-176B + ( + MatmulConfig, + Matmul, + ( + 8192, + 43008, + 14336, + A_dtype, + W_dtype, + out_dtype, + accum_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, + ), + ), + ( + MatmulConfig, + Matmul, + ( + 8192, + 14336, + 14336, + A_dtype, + W_dtype, + out_dtype, + accum_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, + ), + ), + ( + MatmulConfig, + Matmul, + ( + 8192, + 57344, + 14336, + A_dtype, + W_dtype, + out_dtype, + accum_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, + ), + ), + ( + MatmulConfig, + Matmul, + ( + 8192, + 14336, + 57344, + A_dtype, + W_dtype, + out_dtype, + accum_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, + ), + ), + # # OPT-65B + ( + MatmulConfig, + Matmul, + ( + 8192, + 9216, + 9216, + A_dtype, + W_dtype, + out_dtype, + accum_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, + ), + ), + ( + MatmulConfig, + Matmul, + ( + 8192, + 36864, + 9216, + A_dtype, + W_dtype, + out_dtype, + accum_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, + ), + ), + ( + MatmulConfig, + Matmul, + ( + 8192, + 9216, + 36864, + A_dtype, + W_dtype, + out_dtype, + accum_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, + ), + ), + ( + MatmulConfig, + Matmul, + ( + 8192, + 22016, + 8192, + A_dtype, + W_dtype, + out_dtype, + accum_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, + ), + ), + # # LLAMA-70B/65B + ( + MatmulConfig, + Matmul, + ( + 8192, + 8192, + 22016, + A_dtype, + W_dtype, + out_dtype, + accum_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, + ), + ), + ( + MatmulConfig, + Matmul, + ( + 8192, + 8192, + 8192, + A_dtype, + W_dtype, + out_dtype, + accum_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, + ), + ), + ( + MatmulConfig, + Matmul, + ( + 8192, + 28672, + 8192, + A_dtype, + W_dtype, + out_dtype, + accum_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, + ), + ), + ( + MatmulConfig, + Matmul, + ( + 8192, + 8192, + 28672, + A_dtype, + W_dtype, + out_dtype, + accum_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, + ), + ), +] + +benchmark_sets = [] +benchmark_sets.extend(test_shapes) + +# fmt:on + +benchmark_results = {} +for config, operator, input_args in benchmark_sets: + config = config(*input_args) + matmul = operator(config, target=target, enable_tuning=True) + kernel_latency = matmul.profile_latency() + + print("Time cost is: {:.3f} ms".format(kernel_latency)) + + profile_config = { + f"{operator.__name__}-{'-'.join([str(i) for i in input_args])}": { + "BitBLAS_top20_latency": kernel_latency, + } + } + + benchmark_results.update(profile_config) + +# Define headers for the table +headers = [ + "PrimFunc", + "Input Arguments", + "BitBLAS Top20 Latency", +] + +col_widths = [0, 0, 0] +for config, values in benchmark_results.items(): + args = config.split("-") + func_name = args[0] + input_args = "-".join(args[1:]) + col_widths[0] = max((max(len(str(headers[0])), len(func_name)) + 2), + col_widths[0]) + col_widths[1] = max((max(len(str(headers[1])), len(input_args)) 
+ 2, + col_widths[1])) + col_widths[2] = max( + max(len(str(headers[2])), + len(f"{values['BitBLAS_top20_latency']:.3f} ms")) + 2, + col_widths[2], + ) + break + +for i, header in enumerate(headers): + headers[i] = header.ljust(col_widths[i]) + +print("".join(headers)) + +print("-" * sum(col_widths)) + +for config, values in benchmark_results.items(): + args = config.split("-") + func_name = args[0] + input_args = "-".join(args[1:]) + row = [ + func_name, + input_args, + f"{values['BitBLAS_top20_latency']:.3f} ms", + ] + print("".join([str(i).ljust(col_widths[j]) for j, i in enumerate(row)]) + + "\n") diff --git a/vllm/model_executor/layers/quantization/bitblas.py b/vllm/model_executor/layers/quantization/bitblas.py index bf47ce7a6a73..1e5ba8f8fbf5 100644 --- a/vllm/model_executor/layers/quantization/bitblas.py +++ b/vllm/model_executor/layers/quantization/bitblas.py @@ -13,6 +13,7 @@ try: import bitblas + from bitblas import Matmul, MatmulConfig from bitblas.utils import auto_detect_nvidia_target except ImportError as e: bitblas_import_exception = e @@ -467,7 +468,6 @@ def _configure_bitblas_matmul( bits, out_dtype="float16", ): - from bitblas import MatmulConfig bitblas_dtype = self.BITBLAS_DTYPES[params_dtype] @@ -509,7 +509,6 @@ def _configure_bitblas_matmul( matmul_config, enable_tuning) def _get_or_create_bitblas_operator(self, config, enable_tuning): - from bitblas import Matmul from bitblas.cache import global_operator_cache if global_operator_cache.size() == 0: From bffc05b2076da2a7ce7a0d12018381a96d06ece6 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 20 Aug 2024 16:51:57 +0000 Subject: [PATCH 20/24] Implement Test --- benchmarks/kernels/benchmark_bitblas.py | 44 ++++++------- tests/models/test_bitblas.py | 65 +++++++++++++++++++ tests/models/test_bitnet.py | 65 +++++++++++++++++++ tests/models/test_gptq_bitblas.py | 63 ++++++++++++++++++ vllm/config.py | 15 ++++- .../layers/quantization/gptq_bitblas.py | 26 +++++++- 6 files changed, 248 insertions(+), 30 deletions(-) create mode 100644 tests/models/test_bitblas.py create mode 100644 tests/models/test_bitnet.py create mode 100644 tests/models/test_gptq_bitblas.py diff --git a/benchmarks/kernels/benchmark_bitblas.py b/benchmarks/kernels/benchmark_bitblas.py index 410355a42120..f547ad3762a5 100644 --- a/benchmarks/kernels/benchmark_bitblas.py +++ b/benchmarks/kernels/benchmark_bitblas.py @@ -1,16 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from vllm.model_executor.layers.quantization.bitblas import ( - BITBLAS_TARGET, - Matmul, - MatmulConfig, -) +from vllm.model_executor.layers.quantization.bitblas import (BITBLAS_TARGET, + Matmul, + MatmulConfig) from vllm.utils import FlexibleArgumentParser parser = FlexibleArgumentParser( - description="Benchmark BitBLAS int4 on a specific target." -) + description="Benchmark BitBLAS int4 on a specific target.") # Add arguments to the parser parser.add_argument( @@ -19,10 +16,10 @@ default=BITBLAS_TARGET, help="Specify the target device for benchmarking.", ) -parser.add_argument( - "--group_size", type=int, default=None, - help="Group size for grouped quantization." -) +parser.add_argument("--group_size", + type=int, + default=None, + help="Group size for grouped quantization.") parser.add_argument( "--A_dtype", type=str, @@ -69,18 +66,17 @@ choices=["nt", "nn"], help="Matrix layout, 'nt' for non-transpose A and transpose W.", ) -parser.add_argument( - "--with_bias", action="store_true", help="Include bias in the benchmark." 
-) +parser.add_argument("--with_bias", + action="store_true", + help="Include bias in the benchmark.") parser.add_argument( "--with_scaling", action="store_true", help="Include scaling factor in the quantization.", ) -parser.add_argument( - "--with_zeros", action="store_true", - help="Include zeros in the quantization." -) +parser.add_argument("--with_zeros", + action="store_true", + help="Include zeros in the quantization.") parser.add_argument( "--zeros_mode", type=str, @@ -644,12 +640,12 @@ args = config.split("-") func_name = args[0] input_args = "-".join(args[1:]) - col_widths[0] = max((max(len(str(headers[0])), len(func_name)) + 2), + col_widths[0] = max((max(len(str(headers[0])), len(func_name)) + 2), col_widths[0]) - col_widths[1] = max((max(len(str(headers[1])), len(input_args)) + 2, - col_widths[1])) + col_widths[1] = max( + (max(len(str(headers[1])), len(input_args)) + 2, col_widths[1])) col_widths[2] = max( - max(len(str(headers[2])), + max(len(str(headers[2])), len(f"{values['BitBLAS_top20_latency']:.3f} ms")) + 2, col_widths[2], ) @@ -671,5 +667,5 @@ input_args, f"{values['BitBLAS_top20_latency']:.3f} ms", ] - print("".join([str(i).ljust(col_widths[j]) for j, i in enumerate(row)]) - + "\n") + print("".join([str(i).ljust(col_widths[j]) + for j, i in enumerate(row)]) + "\n") diff --git a/tests/models/test_bitblas.py b/tests/models/test_bitblas.py new file mode 100644 index 000000000000..0ea2a17ba072 --- /dev/null +++ b/tests/models/test_bitblas.py @@ -0,0 +1,65 @@ +"""Compare the outputs of a GPTQ model to a bitblas model. + +Note: GPTQ and bitblas do not have bitwise correctness. +As a result, in this test, we just confirm that the top selected tokens of the +bitblas/GPTQ models are in the top 3 selections of each other. + +Note: bitblas internally uses locks to synchronize the threads. This can +result in very slight nondeterminism for bitblas. As a result, we re-run the test +up to 3 times to see if we pass. + +Run `pytest tests/models/test_bitblas.py`. 
+""" +from dataclasses import dataclass + +import pytest + +from tests.quantization.utils import is_quant_method_supported + +from .utils import check_logprobs_close + + +@dataclass +class ModelPair: + model_bitblas: str + model_gptq: str + + +model_pairs = [ + ModelPair(model_bitblas="hxbgsyxh/opt-125m-4bit-128g-bitblas", + model_gptq="hxbgsyxh/opt-125m-4bit-128g"), +] + + +@pytest.mark.flaky(reruns=2) +@pytest.mark.skipif(True, + reason="BitBLAS takes too much time for tuning.") +@pytest.mark.parametrize("model_pair", model_pairs) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models( + vllm_runner, + example_prompts, + model_pair: ModelPair, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + with vllm_runner(model_pair.model_bitblas, + dtype=dtype, + quantization="bitblas") as bitblas_model: + bitblas_outputs = bitblas_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + with vllm_runner(model_pair.model_gptq, dtype=dtype, + quantization="gptq") as gptq_model: + gptq_outputs = gptq_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + check_logprobs_close( + outputs_0_lst=gptq_outputs, + outputs_1_lst=bitblas_outputs, + name_0="gptq", + name_1="bitblas", + ) diff --git a/tests/models/test_bitnet.py b/tests/models/test_bitnet.py new file mode 100644 index 000000000000..0cdb617f028c --- /dev/null +++ b/tests/models/test_bitnet.py @@ -0,0 +1,65 @@ +"""Compare the outputs of a bitnet model to a bitblas model. + +Note: bitnet and bitblas do not have bitwise correctness. +As a result, in this test, we just confirm that the top selected tokens of the +bitblas/bitnet models are in the top 3 selections of each other. + +Note: bitblas internally uses locks to synchronize the threads. This can +result in very slight nondeterminism for bitblas. As a result, we re-run the test +up to 3 times to see if we pass. + +Run `pytest tests/models/test_bitblas.py`. 
+""" +from dataclasses import dataclass + +import pytest + +from tests.quantization.utils import is_quant_method_supported + +from .utils import check_logprobs_close + + +@dataclass +class ModelPair: + model_bitblas: str + model_bitnet: str + + +model_pairs = [ + ModelPair(model_bitblas="hxbgsyxh/bitnet_b1_58-3B_bitblas", + model_bitnet="hxbgsyxh/bitnet_b1_58-3B"), +] + + +@pytest.mark.flaky(reruns=2) +@pytest.mark.skipif(True, + reason="BitBLAS takes too much time for tuning.") +@pytest.mark.parametrize("model_pair", model_pairs) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models( + vllm_runner, + example_prompts, + model_pair: ModelPair, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + with vllm_runner(model_pair.model_bitblas, + dtype=dtype, + quantization="bitblas") as bitblas_model: + bitblas_outputs = bitblas_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + with vllm_runner(model_pair.model_bitnet, dtype=dtype, + quantization="bitnet") as bitnet_model: + bitnet_outputs = bitnet_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + check_logprobs_close( + outputs_0_lst=bitnet_outputs, + outputs_1_lst=bitblas_outputs, + name_0="bitnet", + name_1="bitblas", + ) diff --git a/tests/models/test_gptq_bitblas.py b/tests/models/test_gptq_bitblas.py new file mode 100644 index 000000000000..3c4c08947dc6 --- /dev/null +++ b/tests/models/test_gptq_bitblas.py @@ -0,0 +1,63 @@ +"""Compare the outputs of a GPTQ model to a bitblas model. + +Note: GPTQ and bitblas do not have bitwise correctness. +As a result, in this test, we just confirm that the top selected tokens of the +bitblas/GPTQ models are in the top 3 selections of each other. + +Note: bitblas internally uses locks to synchronize the threads. This can +result in very slight nondeterminism for bitblas. As a result, we re-run the test +up to 3 times to see if we pass. + +Run `pytest tests/models/test_bitblas.py`. 
+""" +from dataclasses import dataclass + +import pytest + +from tests.quantization.utils import is_quant_method_supported + +from .utils import check_logprobs_close + + +@dataclass +class ModelPair: + model_gptq: str + + +model_pairs = [ + ModelPair(model_gptq="hxbgsyxh/opt-125m-4bit-128g"), +] + + +@pytest.mark.flaky(reruns=2) +@pytest.mark.skipif(True, + reason="BitBLAS takes too much time for tuning.") +@pytest.mark.parametrize("model_pair", model_pairs) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models( + vllm_runner, + example_prompts, + model_pair: ModelPair, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + with vllm_runner(model_pair.model_gptq, + dtype=dtype, + quantization="bitblas") as bitblas_model: + bitblas_outputs = bitblas_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + with vllm_runner(model_pair.model_gptq, dtype=dtype, + quantization="gptq") as gptq_model: + gptq_outputs = gptq_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + check_logprobs_close( + outputs_0_lst=gptq_outputs, + outputs_1_lst=bitblas_outputs, + name_0="gptq", + name_1="gptq_bitblas", + ) diff --git a/vllm/config.py b/vllm/config.py index f9c58838aa60..88345ab90933 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -267,9 +267,18 @@ def _verify_quantization(self) -> None: supported_quantization = [*QUANTIZATION_METHODS] rocm_supported_quantization = ["gptq", "squeezellm", "fp8"] optimized_quantization_methods = [ - "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin", - "gptq_bitblas", "bitblas", "bitnet_bitblas", "fbgemm_fp8", - "compressed_tensors", "compressed-tensors", "experts_int8", + "fp8", + "marlin", + "gptq_marlin_24", + "gptq_marlin", + "awq_marlin", + "gptq_bitblas", + "bitblas", + "bitnet_bitblas", + "fbgemm_fp8", + "compressed_tensors", + "compressed-tensors", + "experts_int8", ] tpu_supported_quantization = ["tpu_int8"] if self.quantization is not None: diff --git a/vllm/model_executor/layers/quantization/gptq_bitblas.py b/vllm/model_executor/layers/quantization/gptq_bitblas.py index 3a5c7b45e764..0dfd9d638944 100644 --- a/vllm/model_executor/layers/quantization/gptq_bitblas.py +++ b/vllm/model_executor/layers/quantization/gptq_bitblas.py @@ -32,7 +32,7 @@ GPTQ_BITBLAS_SUPPORTED_SYM = [False, True] -def unpack_qzeros(qzeros, bits) -> torch.Tensor: +def unpack_qzeros(qzeros, bits, is_gptq_v2=False) -> torch.Tensor: qzeros = qzeros.view(torch.int32) elems_per_int32 = 32 // bits unpacked_zeros = torch.zeros( @@ -46,8 +46,26 @@ def unpack_qzeros(qzeros, bits) -> torch.Tensor: i = col % elems_per_int32 unpacked_zeros[:, col] = (qzeros[:, col // elems_per_int32] >> (bits * i)) & 0xF + if not is_gptq_v2: + return unpacked_zeros + 1 + return unpacked_zeros - return unpacked_zeros + 1 + +def unpack_qweight(qweight, bits): + qweight = qweight.view(torch.int8) + elems_per_int8 = 8 // bits + unpacked_weight = torch.zeros( + (qweight.shape[0], qweight.shape[1] * elems_per_int8), + dtype=torch.int8, + device=qweight.device, + requires_grad=False, + ) + for col in range(unpacked_weight.shape[1]): + i = col % elems_per_int8 + unpacked_weight[:, col] = (qweight[:, col // elems_per_int8] >> + (bits * i)) + + return torch.bitwise_and(unpacked_weight, 2**bits - 1) class GPTQBitBLASConfig(QuantizationConfig): @@ -432,9 +450,11 @@ def repack_bitblas_from_gptq(self, b_q_weight: torch.Tensor, # (outfeatures, 
infeatures), should be transposed. qweight = b_q_weight.T.contiguous().view( self.quant_config.TORCH_BITBLAS_STORAGE_DTYPE) + intweight = unpack_qweight(qweight, + self.quant_config.weight_bits).contiguous() if self.bitblas_matmul.weight_transform is not None: qweight = self.bitblas_matmul.weight_transform( - qweight.cpu()).cuda() + intweight.cpu()).cuda() # scales in gptq old quant linear stored with # (infeatures // group_size, outfeatures), should be transposed. scales = scales.T.contiguous() From 8b0972b924a38ccb342bde74e4c30de8bdaec242 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 20 Aug 2024 16:57:23 +0000 Subject: [PATCH 21/24] lint fix --- tests/models/test_bitblas.py | 9 +++------ tests/models/test_bitnet.py | 12 +++++------- tests/models/test_gptq_bitblas.py | 9 +++------ 3 files changed, 11 insertions(+), 19 deletions(-) diff --git a/tests/models/test_bitblas.py b/tests/models/test_bitblas.py index 0ea2a17ba072..a8b3c2fb149b 100644 --- a/tests/models/test_bitblas.py +++ b/tests/models/test_bitblas.py @@ -5,8 +5,8 @@ bitblas/GPTQ models are in the top 3 selections of each other. Note: bitblas internally uses locks to synchronize the threads. This can -result in very slight nondeterminism for bitblas. As a result, we re-run the test -up to 3 times to see if we pass. +result in very slight nondeterminism for bitblas. As a result, we re-run the +test up to 3 times to see if we pass. Run `pytest tests/models/test_bitblas.py`. """ @@ -14,8 +14,6 @@ import pytest -from tests.quantization.utils import is_quant_method_supported - from .utils import check_logprobs_close @@ -32,8 +30,7 @@ class ModelPair: @pytest.mark.flaky(reruns=2) -@pytest.mark.skipif(True, - reason="BitBLAS takes too much time for tuning.") +@pytest.mark.skipif(True, reason="BitBLAS takes too much time for tuning.") @pytest.mark.parametrize("model_pair", model_pairs) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [32]) diff --git a/tests/models/test_bitnet.py b/tests/models/test_bitnet.py index 0cdb617f028c..e751bcc905dd 100644 --- a/tests/models/test_bitnet.py +++ b/tests/models/test_bitnet.py @@ -5,8 +5,8 @@ bitblas/bitnet models are in the top 3 selections of each other. Note: bitblas internally uses locks to synchronize the threads. This can -result in very slight nondeterminism for bitblas. As a result, we re-run the test -up to 3 times to see if we pass. +result in very slight nondeterminism for bitblas. As a result, we re-run the +test up to 3 times to see if we pass. Run `pytest tests/models/test_bitblas.py`. 
""" @@ -14,8 +14,6 @@ import pytest -from tests.quantization.utils import is_quant_method_supported - from .utils import check_logprobs_close @@ -32,8 +30,7 @@ class ModelPair: @pytest.mark.flaky(reruns=2) -@pytest.mark.skipif(True, - reason="BitBLAS takes too much time for tuning.") +@pytest.mark.skipif(True, reason="BitBLAS takes too much time for tuning.") @pytest.mark.parametrize("model_pair", model_pairs) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [32]) @@ -52,7 +49,8 @@ def test_models( bitblas_outputs = bitblas_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) - with vllm_runner(model_pair.model_bitnet, dtype=dtype, + with vllm_runner(model_pair.model_bitnet, + dtype=dtype, quantization="bitnet") as bitnet_model: bitnet_outputs = bitnet_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) diff --git a/tests/models/test_gptq_bitblas.py b/tests/models/test_gptq_bitblas.py index 3c4c08947dc6..03ffa4eca385 100644 --- a/tests/models/test_gptq_bitblas.py +++ b/tests/models/test_gptq_bitblas.py @@ -5,8 +5,8 @@ bitblas/GPTQ models are in the top 3 selections of each other. Note: bitblas internally uses locks to synchronize the threads. This can -result in very slight nondeterminism for bitblas. As a result, we re-run the test -up to 3 times to see if we pass. +result in very slight nondeterminism for bitblas. As a result, we re-run the +test up to 3 times to see if we pass. Run `pytest tests/models/test_bitblas.py`. """ @@ -14,8 +14,6 @@ import pytest -from tests.quantization.utils import is_quant_method_supported - from .utils import check_logprobs_close @@ -30,8 +28,7 @@ class ModelPair: @pytest.mark.flaky(reruns=2) -@pytest.mark.skipif(True, - reason="BitBLAS takes too much time for tuning.") +@pytest.mark.skipif(True, reason="BitBLAS takes too much time for tuning.") @pytest.mark.parametrize("model_pair", model_pairs) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [32]) From 8e1a7e89e36b9bf051efc4caa3faa114779c5679 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 20 Aug 2024 17:21:23 +0000 Subject: [PATCH 22/24] install bitblas by default to pass the doc gen. 
--- requirements-common.txt | 1 + requirements-test.txt | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/requirements-common.txt b/requirements-common.txt index b6bed8a73d8c..95d81760606e 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -26,3 +26,4 @@ librosa # Required for audio processing soundfile # Required for audio processing gguf == 0.9.1 importlib_metadata +bitblas==0.0.1.dev14 # Require for bitblas kernel diff --git a/requirements-test.txt b/requirements-test.txt index cdbc3e50cc9e..1ce0be9042b5 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -30,4 +30,5 @@ aiohttp # quantization bitsandbytes==0.42.0 -buildkite-test-collector==0.1.8 \ No newline at end of file +buildkite-test-collector==0.1.8 +bitblas==0.0.1.dev14 From 7fbbccf9349a35ca0c87fb6bb7e801be11daa6e7 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 20 Aug 2024 17:36:00 +0000 Subject: [PATCH 23/24] hide the bitblas import --- benchmarks/kernels/benchmark_bitblas.py | 7 +++--- .../layers/quantization/bitblas.py | 22 +++++-------------- 2 files changed, 8 insertions(+), 21 deletions(-) diff --git a/benchmarks/kernels/benchmark_bitblas.py b/benchmarks/kernels/benchmark_bitblas.py index f547ad3762a5..c454688f1e68 100644 --- a/benchmarks/kernels/benchmark_bitblas.py +++ b/benchmarks/kernels/benchmark_bitblas.py @@ -1,9 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from vllm.model_executor.layers.quantization.bitblas import (BITBLAS_TARGET, - Matmul, - MatmulConfig) +from bitblas import Matmul, MatmulConfig, auto_detect_nvidia_target + from vllm.utils import FlexibleArgumentParser parser = FlexibleArgumentParser( @@ -13,7 +12,7 @@ parser.add_argument( "--target", type=str, - default=BITBLAS_TARGET, + default=auto_detect_nvidia_target(), help="Specify the target device for benchmarking.", ) parser.add_argument("--group_size", diff --git a/vllm/model_executor/layers/quantization/bitblas.py b/vllm/model_executor/layers/quantization/bitblas.py index 1e5ba8f8fbf5..4c991f80e9b1 100644 --- a/vllm/model_executor/layers/quantization/bitblas.py +++ b/vllm/model_executor/layers/quantization/bitblas.py @@ -11,20 +11,6 @@ logger = init_logger(__name__) -try: - import bitblas - from bitblas import Matmul, MatmulConfig - from bitblas.utils import auto_detect_nvidia_target -except ImportError as e: - bitblas_import_exception = e - raise ValueError( - "Trying to use the bitblas backend, but could not import dependencies" - f"with the following error: {bitblas_import_exception}" - ) from bitblas_import_exception - -BITBLAS_TARGET = auto_detect_nvidia_target() -BITBLAS_DATABASE_PATH = bitblas.cache.get_database_path() - BITBLAS_SUPPORTED_NUM_BITS = [1, 2, 4, 8] BITBLAS_SUPPORTED_SYM = [False, True] @@ -468,7 +454,7 @@ def _configure_bitblas_matmul( bits, out_dtype="float16", ): - + from bitblas import MatmulConfig bitblas_dtype = self.BITBLAS_DTYPES[params_dtype] if self.quant_config.quant_method == "gptq": @@ -509,8 +495,10 @@ def _configure_bitblas_matmul( matmul_config, enable_tuning) def _get_or_create_bitblas_operator(self, config, enable_tuning): - from bitblas.cache import global_operator_cache - + from bitblas import Matmul, auto_detect_nvidia_target + from bitblas.cache import get_database_path, global_operator_cache + BITBLAS_DATABASE_PATH = get_database_path() + BITBLAS_TARGET = auto_detect_nvidia_target() if global_operator_cache.size() == 0: global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, BITBLAS_TARGET) From 
c487e69b29a9934c3463188953f9c38d7ca14657 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 20 Aug 2024 17:44:28 +0000 Subject: [PATCH 24/24] import fix --- .../layers/quantization/bitnet_bitblas.py | 21 +++++-------------- .../layers/quantization/gptq_bitblas.py | 21 ++++--------------- 2 files changed, 9 insertions(+), 33 deletions(-) diff --git a/vllm/model_executor/layers/quantization/bitnet_bitblas.py b/vllm/model_executor/layers/quantization/bitnet_bitblas.py index 04625a536df8..597eac3eef54 100644 --- a/vllm/model_executor/layers/quantization/bitnet_bitblas.py +++ b/vllm/model_executor/layers/quantization/bitnet_bitblas.py @@ -2,7 +2,6 @@ from enum import Enum from typing import Any, Dict, List, Optional -import bitblas.cache import torch from torch.nn.parameter import Parameter @@ -16,20 +15,6 @@ logger = init_logger(__name__) -try: - import bitblas - from bitblas.utils import auto_detect_nvidia_target -except ImportError as e: - bitblas_import_exception = e - error_message = ( - "Trying to use the bitblas backend, but could not import dependencies " - f"with the following error: {bitblas_import_exception}") - raise ValueError(error_message) from bitblas_import_exception - -bitblas.set_log_level("Debug") -BITBLAS_TARGET = auto_detect_nvidia_target() -BITBLAS_DATABASE_PATH = bitblas.cache.get_database_path() - BITNET_BITBLAS_SUPPORTED_NUM_BITS = [1, 2, 4, 8] @@ -285,7 +270,11 @@ def _configure_bitblas_matmul( def _get_or_create_bitblas_operator(self, config, enable_tuning): from bitblas import Matmul - from bitblas.cache import global_operator_cache + from bitblas.cache import get_database_path, global_operator_cache + from bitblas.utils import auto_detect_nvidia_target + + BITBLAS_DATABASE_PATH = get_database_path() + BITBLAS_TARGET = auto_detect_nvidia_target() if global_operator_cache.size() == 0: global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, diff --git a/vllm/model_executor/layers/quantization/gptq_bitblas.py b/vllm/model_executor/layers/quantization/gptq_bitblas.py index 0dfd9d638944..22c6f2062e92 100644 --- a/vllm/model_executor/layers/quantization/gptq_bitblas.py +++ b/vllm/model_executor/layers/quantization/gptq_bitblas.py @@ -2,7 +2,6 @@ from enum import Enum from typing import Any, Dict, List, Optional -import bitblas.cache import torch from torch.nn.parameter import Parameter @@ -14,20 +13,6 @@ logger = init_logger(__name__) -try: - import bitblas - from bitblas.utils import auto_detect_nvidia_target -except ImportError as e: - bitblas_import_exception = e - raise ValueError( - "Trying to use the bitblas backend, but could not import dependencies" - f"with the following error: {bitblas_import_exception}" - ) from bitblas_import_exception - -bitblas.set_log_level("Debug") -BITBLAS_TARGET = auto_detect_nvidia_target() -BITBLAS_DATABASE_PATH = bitblas.cache.get_database_path() - GPTQ_BITBLAS_SUPPORTED_NUM_BITS = [1, 2, 4, 8] GPTQ_BITBLAS_SUPPORTED_SYM = [False, True] @@ -415,8 +400,10 @@ def _configure_bitblas_matmul( matmul_config, enable_tuning) def _get_or_create_bitblas_operator(self, config, enable_tuning): - from bitblas import Matmul - from bitblas.cache import global_operator_cache + from bitblas import Matmul, auto_detect_nvidia_target + from bitblas.cache import get_database_path, global_operator_cache + BITBLAS_DATABASE_PATH = get_database_path() + BITBLAS_TARGET = auto_detect_nvidia_target() if global_operator_cache.size() == 0: global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH,