From 5b278486b27d62477318037f4e6b5bc8e88b1506 Mon Sep 17 00:00:00 2001 From: "dongmao.zhang" Date: Fri, 21 Jun 2024 19:43:14 +0000 Subject: [PATCH] [bitsandbytes]: support read bnb pre-quantized model such as lllyasviel/omost-llama-3-8b-4bits --- docs/source/index.rst | 1 + docs/source/quantization/bnb.rst | 44 ++++++++++ tests/quantization/test_bitsandbytes.py | 16 +++- vllm/config.py | 1 + vllm/engine/arg_utils.py | 4 +- .../layers/quantization/bitsandbytes.py | 25 +----- vllm/model_executor/model_loader/loader.py | 88 ++++++++++++++++--- .../model_loader/weight_utils.py | 1 + 8 files changed, 142 insertions(+), 38 deletions(-) create mode 100644 docs/source/quantization/bnb.rst diff --git a/docs/source/index.rst b/docs/source/index.rst index 05133eb6d867a..95a10294bf1f6 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -102,6 +102,7 @@ Documentation quantization/supported_hardware quantization/auto_awq + quantization/bnb quantization/fp8 quantization/fp8_e5m2_kvcache quantization/fp8_e4m3_kvcache diff --git a/docs/source/quantization/bnb.rst b/docs/source/quantization/bnb.rst new file mode 100644 index 0000000000000..a5900977fbae4 --- /dev/null +++ b/docs/source/quantization/bnb.rst @@ -0,0 +1,44 @@ +.. _bits_and_bytes: + +BitsAndBytes +================== + +vLLM now supports `BitsAndBytes `_ for more efficient model inference. +BitsAndBytes quantizes models to reduce memory usage and enhance performance without significantly sacrificing accuracy. +Compared to other quantization methods, BitsAndBytes eliminates the need for calibrating the quantized model with input data. + +Below are the steps to utilize BitsAndBytes with vLLM. + +.. code-block:: console + + $ pip install bitsandbytes>=0.42.0 + +vLLM reads the model's config file and supports both in-flight quantization and pre-quantized checkpoint. + +Read quantized checkpoint. + +You can find bitsandbytes quantized models on https://huggingface.co/models?other=bitsandbytes. +And usually, these repositories have a config.json file that includes a quantization_config section. + +-------------------------- + +.. code-block:: python + + from vllm import LLM + import torch + # unsloth/tinyllama-bnb-4bit is a pre-quantized checkpoint. + model_id = "unsloth/tinyllama-bnb-4bit" + llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True, \ + quantization="bitsandbytes", load_format="bitsandbytes") + +Inflight quantization: load as 4bit quantization +------------------------------------------------ + +.. code-block:: python + + from vllm import LLM + import torch + model_id = "huggyllama/llama-7b" + llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True, \ + quantization="bitsandbytes", load_format="bitsandbytes") + diff --git a/tests/quantization/test_bitsandbytes.py b/tests/quantization/test_bitsandbytes.py index 953fd9ba939c8..8b6448d7937c6 100644 --- a/tests/quantization/test_bitsandbytes.py +++ b/tests/quantization/test_bitsandbytes.py @@ -8,15 +8,20 @@ from tests.quantization.utils import is_quant_method_supported from vllm import SamplingParams +models_to_test = [ + ('huggyllama/llama-7b', 'quantize model inflight'), + ('lllyasviel/omost-llama-3-8b-4bits', 'read pre-quantized model'), +] + @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), reason='bitsandbytes is not supported on this GPU type.') -def test_load_bnb_model(vllm_runner) -> None: - with vllm_runner('huggyllama/llama-7b', +@pytest.mark.parametrize("model_name, description", models_to_test) +def test_load_bnb_model(vllm_runner, model_name, description) -> None: + with vllm_runner(model_name, quantization='bitsandbytes', load_format='bitsandbytes', enforce_eager=True) as llm: - model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 # check the weights in MLP & SelfAttention are quantized to torch.uint8 @@ -72,5 +77,10 @@ def test_load_bnb_model(vllm_runner) -> None: # compare the first line of the output actual_output = outputs[index][1][0].split('\n', 1)[0] expected_output = expected_outputs[index].split('\n', 1)[0] + + # LLM's output should be larger than or equal to expected output + assert len(actual_output) >= len(expected_output) + + actual_output = actual_output[:len(expected_output)] assert actual_output == expected_output, ( f'Expected: {expected_output}, but got: {actual_output}') diff --git a/vllm/config.py b/vllm/config.py index 8d004902fe4ff..8cd8633d50c7c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -524,6 +524,7 @@ class LoadConfig: mainly for profiling. "tensorizer" will use CoreWeave's tensorizer library for fast weight loading. + "bitsandbytes" will load nf4 type weights. """ load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ef31612420c94..63c75ce46f313 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -623,8 +623,8 @@ def create_engine_config(self, ) -> EngineConfig: # bitsandbytes quantization needs a specific model loader # so we make sure the quant method and the load format are consistent if (self.quantization == "bitsandbytes" or - self.qlora_adapter_name_or_path is not None) and \ - self.load_format != "bitsandbytes": + self.qlora_adapter_name_or_path is not None) and \ + self.load_format != "bitsandbytes": raise ValueError( "BitsAndBytes quantization and QLoRA adapter only support " f"'bitsandbytes' load format, but got {self.load_format}") diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 969958d9b5448..007ec2c96bd18 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -15,19 +15,11 @@ class BitsAndBytesConfig(QuantizationConfig): Reference: https://arxiv.org/abs/2305.14314 """ - def __init__( - self, - adapter_name_or_path: str, - target_modules: List[str], - ) -> None: - - self.adapter_name_or_path = adapter_name_or_path - self.target_modules = target_modules + def __init__(self, ) -> None: + pass def __repr__(self) -> str: - return ( - f"BitsAndBytesConfig(adapter_name_or_path={self.adapter_name_or_path}" - ) + return "BitsAndBytesConfig" @classmethod def get_name(self) -> str: @@ -49,16 +41,7 @@ def get_config_filenames() -> List[str]: @classmethod def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig": - adapter_name = cls.get_from_keys(config, ["adapter_name_or_path"]) - default_target_modules = [ - "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj", - "o_proj" - ] - if adapter_name == "": - target_modules = default_target_modules - else: - target_modules = cls.get_from_keys(config, ["target_modules"]) - return cls(adapter_name, target_modules) + return cls() def get_quant_method( self, diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index d3babcf9c3451..be8aa8986efe7 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -681,8 +681,14 @@ def _prepare_weights(self, model_name_or_path: str, return hf_weights_files, matched_pattern == "*.safetensors" + def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool): + if use_safetensors: + return safetensors_weights_iterator(hf_weights_files) + else: + return pt_weights_iterator(hf_weights_files) + def _get_quantized_weights_iterator( - self, model_name_or_path: str, revision: Optional[str] + self, model_name_or_path: str, revision: Optional[str], pre_quant: bool ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str, Any]]: """Get an iterator to the model weights with bitsandbytes quantization, @@ -691,6 +697,7 @@ def _get_quantized_weights_iterator( # only load the bitsandbytes module when needed try: import bitsandbytes + from bitsandbytes.functional import QuantState if bitsandbytes.__version__ < "0.42.0": raise ImportError("bitsandbytes version is wrong. Please " "install bitsandbytes>=0.42.0.") @@ -704,17 +711,63 @@ def _get_quantized_weights_iterator( model_name_or_path, revision) quant_state_dict = {} - if use_safetensors: - weight_iterator = safetensors_weights_iterator(hf_weights_files) - else: - weight_iterator = pt_weights_iterator(hf_weights_files) - def generator(): + def quantized_checkpoint() -> Generator: + # First iterate over all quant state weights + weight_iterator = self._hf_weight_iter(hf_weights_files, + use_safetensors) + temp_state_dict = {} for weight_name, weight_tensor in weight_iterator: + if weight_name.endswith(".weight"): + continue + # TODO: only nf4 quantization is supported for now + if weight_name.endswith(".quant_state.bitsandbytes__fp4"): + raise NotImplementedError( + "Only bitsandbytes_nf4 quantization"\ + "is supported right now. {weight_name} is fp4 quantized" + ) + temp_state_dict[weight_name] = weight_tensor + + # Closure to parse quant_state for each prequant weight + def _parse_quant_state(param_name: str, + temp_state_dict: Dict) -> QuantState: + quant_state = {} + for k in temp_state_dict: + if param_name + "." in k: + quant_state[k] = temp_state_dict[k] + # bitsandbytes library requires + # weight.quant_state.bitsandbytes__nf4 in CPU + quant_state[param_name + + ".quant_state.bitsandbytes__nf4"] = quant_state[ + param_name + + ".quant_state.bitsandbytes__nf4"].cpu().data + return QuantState.from_dict(quant_state, device="cuda") + + # Second iterate over all prequant and normal weights + # pre quantized weights would have a quant_state + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors): + # Filter out all weights whose suffix is not ".weight" + if not weight_name.endswith(".weight"): + continue + if weight_name + ".quant_state.bitsandbytes__nf4" \ + in temp_state_dict: + quant_state = _parse_quant_state(weight_name, + temp_state_dict) + weight_name = weight_name.replace(".weight", ".qweight") + quant_state_dict[weight_name] = quant_state + yield weight_name.replace(".weight", + ".qweight"), weight_tensor + else: + yield weight_name, weight_tensor + + def generator() -> Generator: + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors): if any(target_module in weight_name for target_module in self.target_modules): weight_name = weight_name.replace(".weight", ".qweight") - # bitsandbytes requires data in GPU + # bitsandbytes requires data in GPU loaded_weight = weight_tensor.cuda().data with set_default_torch_dtype(torch.float32): processed_weight, quant_state = quantize_4bit( @@ -728,6 +781,8 @@ def generator(): yield weight_name, processed_weight + if pre_quant: + return quantized_checkpoint(), quant_state_dict return generator(), quant_state_dict def _load_weights(self, model_config: ModelConfig, @@ -745,12 +800,21 @@ def _load_weights(self, model_config: ModelConfig, logger.info("Loading weights with BitsAndBytes quantization. " " May take a while ...") - qweight_iterator, quant_state_dict = ( - self._get_quantized_weights_iterator(model_config.model, - model_config.revision)) + is_quantized_checkpoint = False + quant_config = getattr(model_config.hf_config, "quantization_config", + None) + if quant_config is not None and quant_config.get( + 'quant_method') == "bitsandbytes": + is_quantized_checkpoint = True + + qweight_iterator, quant_state_dict = \ + self._get_quantized_weights_iterator( + model_config.model, model_config.revision, is_quantized_checkpoint) model.load_weights(qweight_iterator) + torch.cuda.empty_cache() + param_dict = dict(model.named_parameters()) stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {} for quant_param_name in quant_state_dict: @@ -788,9 +852,9 @@ def _load_weights(self, model_config: ModelConfig, f"pack_factor not set for parameter {param_name}.") num_elements = [0] * len(quant_states) - for seq, quant_state in enumerate(quant_states.items()): + for seq, quant_state in quant_states.items(): num_elements[seq] = math.prod( - quant_state[1].shape) // pack_ratio + quant_state.shape) // pack_ratio offsets = np.concatenate(([0], np.cumsum(num_elements))) set_weight_attrs(param, {"bnb_shard_offsets": offsets}) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 943022a3f03c7..25459a141f3e9 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -117,6 +117,7 @@ def convert_bin_to_safetensor_file( # TODO(woosuk): Move this to other place. def get_quant_config(model_config: ModelConfig, load_config: LoadConfig) -> QuantizationConfig: + quant_cls = get_quantization_config(model_config.quantization) # Read the quantization config from the HF model config, if available. hf_quant_config = getattr(model_config.hf_config, "quantization_config",