From 9c0a5edb99d2c519ae5bacf45c2719050bc21b3a Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Wed, 6 Nov 2024 03:40:08 +0800
Subject: [PATCH] [Misc] Modify BNB parameter name (#9997)

Signed-off-by: Jee Jee Li
Signed-off-by: Sumit Dubey
---
 .../layers/quantization/bitsandbytes.py    |  9 +++++----
 vllm/model_executor/layers/resampler.py    |  2 +-
 vllm/model_executor/model_loader/loader.py | 14 +++++---------
 3 files changed, 11 insertions(+), 14 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py
index 718967a065192..78965d7b9495c 100644
--- a/vllm/model_executor/layers/quantization/bitsandbytes.py
+++ b/vllm/model_executor/layers/quantization/bitsandbytes.py
@@ -203,8 +203,9 @@ def create_qweight_for_4bit():
             qweight = create_qweight_for_8bit()
         else:
             qweight = create_qweight_for_4bit()
-
-        layer.register_parameter("qweight", qweight)
+        # Enable parameters to have the same name as in the BNB
+        # checkpoint format.
+        layer.register_parameter("weight", qweight)
         set_weight_attrs(qweight, extra_weight_attrs)
 
     def apply(self,
@@ -234,7 +235,7 @@ def _apply_8bit_weight(
             reshape_after_matmul = True
         bf_x = x.to(torch.bfloat16)
 
-        qweight = layer.qweight
+        qweight = layer.weight
         offsets = qweight.bnb_shard_offsets
         quant_states = qweight.bnb_quant_state
         matmul_states = qweight.matmul_state
@@ -313,7 +314,7 @@ def _apply_4bit_weight(
             reshape_after_matmul = True
         bf_x = x.to(torch.bfloat16)
 
-        qweight = layer.qweight
+        qweight = layer.weight
         quant_states = qweight.bnb_quant_state
         offsets = qweight.bnb_shard_offsets
 
diff --git a/vllm/model_executor/layers/resampler.py b/vllm/model_executor/layers/resampler.py
index bce91f1d7fd5e..bca44d2bf2e28 100644
--- a/vllm/model_executor/layers/resampler.py
+++ b/vllm/model_executor/layers/resampler.py
@@ -177,7 +177,7 @@ def __init__(self,
                 embed_dim,
                 bias=False,
                 quant_config=quant_config,
-                prefix=prefix)
+                prefix=f"{prefix}.kv_proj")
         else:
             # Maintain the same return value with ReplicatedLinear.forward
             self.kv_proj = lambda *args, **kwargs: (  # type: ignore # noqa
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index c3e0290f270ae..1f8d531198324 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -892,7 +892,7 @@ def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
             if not weight_name.lower().endswith(".scb"):
                 continue
 
-            weight_key = weight_name.lower().replace(".scb", ".qweight")
+            weight_key = weight_name.lower().replace(".scb", ".weight")
             quant_state_dict[weight_key] = weight_tensor
 
         for weight_name, weight_tensor in self._hf_weight_iter(
@@ -901,11 +901,9 @@ def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
             if self._is_8bit_weight_name(weight_name):
                 continue
 
-            qweight_name = weight_name.replace(".weight", ".qweight")
-
-            if qweight_name in quant_state_dict:
+            if weight_name in quant_state_dict:
                 set_weight_attrs(weight_tensor, {"load_in_8bit": True})
-                yield qweight_name, weight_tensor
+                yield weight_name, weight_tensor
             else:
                 yield weight_name, weight_tensor
 
@@ -950,9 +948,8 @@ def _parse_quant_state(param_name: str,
                 (f"{weight_name}.quant_state.bitsandbytes__fp4" \
                     in temp_state_dict):
                 quant_state = _parse_quant_state(weight_name, temp_state_dict)
-                weight_name = weight_name.replace(".weight", ".qweight")
                 quant_state_dict[weight_name] = quant_state
-                yield weight_name.replace(".weight", ".qweight"), weight_tensor
+                yield weight_name, weight_tensor
             else:
                 yield weight_name, weight_tensor
 
@@ -967,7 +964,6 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
             if any(target_module in weight_name
                    for target_module in self.target_modules
                    ) and weight_name.endswith(".weight"):
-                weight_name = weight_name.replace(".weight", ".qweight")
                 # Without sharding
                 if any(
                         weight_name.startswith(module)
@@ -1093,7 +1089,7 @@ def _load_weights(self, model_config: ModelConfig,
             # Some models, such as MiniCPM V2.5/2.6, contain both
             # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
             # from being incorrectly identified as being present in
-            # 'vpm.encoder.layers.0.self_attn.qkv_proj.qweight
+            # 'vpm.encoder.layers.0.self_attn.qkv_proj.weight
            if shard_pos > 0 and quant_param_name[shard_pos - 1] == ".":
                shard_index = index
                quant_param_name = quant_param_name.replace(
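
Note: the rename works because bitsandbytes-format checkpoints store each
quantized tensor under the plain ".weight" key, with 8-bit scale factors under
a sibling ".SCB" key. With the module parameter also named "weight", the loader
only needs the single ".scb" -> ".weight" mapping seen in the
_quantized_8bit_generator hunk above; the old ".weight" -> ".qweight" rewrites
disappear. A minimal sketch of that mapping follows; the checkpoint key names
are made-up examples, not taken from a real model.

    # Sketch of the key mapping after this patch. The keys below are
    # hypothetical; only the ".scb" -> ".weight" replace mirrors the diff.
    checkpoint_keys = {
        "model.layers.0.mlp.down_proj.weight": "int8 data",  # quantized tensor
        "model.layers.0.mlp.down_proj.SCB": "scales",        # 8-bit scale factors
    }

    quant_state_dict = {}
    for name, tensor in checkpoint_keys.items():
        if name.lower().endswith(".scb"):
            # Store the scales under the quantized tensor's own key,
            # exactly as _quantized_8bit_generator now does.
            quant_state_dict[name.lower().replace(".scb", ".weight")] = tensor

    # The quantized tensor's key now matches directly; before the patch it
    # first had to be rewritten to ".qweight" for this lookup to succeed.
    assert "model.layers.0.mlp.down_proj.weight" in quant_state_dict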