Skip to content

Commit

Permalink
[Misc] Modify BNB parameter name (vllm-project#9997)
Browse files Browse the repository at this point in the history
Signed-off-by: Jee Jee Li <[email protected]>
Signed-off-by: Sumit Dubey <[email protected]>
  • Loading branch information
jeejeelee authored and sumitd2 committed Nov 14, 2024
1 parent c1930f3 commit 9c0a5ed
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 14 deletions.
9 changes: 5 additions & 4 deletions vllm/model_executor/layers/quantization/bitsandbytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,9 @@ def create_qweight_for_4bit():
qweight = create_qweight_for_8bit()
else:
qweight = create_qweight_for_4bit()

layer.register_parameter("qweight", qweight)
# Enable parameters to have the same name as in the BNB
# checkpoint format.
layer.register_parameter("weight", qweight)
set_weight_attrs(qweight, extra_weight_attrs)

def apply(self,
Expand Down Expand Up @@ -234,7 +235,7 @@ def _apply_8bit_weight(
reshape_after_matmul = True
bf_x = x.to(torch.bfloat16)

qweight = layer.qweight
qweight = layer.weight
offsets = qweight.bnb_shard_offsets
quant_states = qweight.bnb_quant_state
matmul_states = qweight.matmul_state
Expand Down Expand Up @@ -313,7 +314,7 @@ def _apply_4bit_weight(
reshape_after_matmul = True
bf_x = x.to(torch.bfloat16)

qweight = layer.qweight
qweight = layer.weight
quant_states = qweight.bnb_quant_state
offsets = qweight.bnb_shard_offsets

Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/resampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def __init__(self,
embed_dim,
bias=False,
quant_config=quant_config,
prefix=prefix)
prefix=f"{prefix}.kv_proj")
else:
# Maintain the same return value with ReplicatedLinear.forward
self.kv_proj = lambda *args, **kwargs: ( # type: ignore # noqa
Expand Down
14 changes: 5 additions & 9 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,7 +892,7 @@ def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
if not weight_name.lower().endswith(".scb"):
continue

weight_key = weight_name.lower().replace(".scb", ".qweight")
weight_key = weight_name.lower().replace(".scb", ".weight")
quant_state_dict[weight_key] = weight_tensor

for weight_name, weight_tensor in self._hf_weight_iter(
Expand All @@ -901,11 +901,9 @@ def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
if self._is_8bit_weight_name(weight_name):
continue

qweight_name = weight_name.replace(".weight", ".qweight")

if qweight_name in quant_state_dict:
if weight_name in quant_state_dict:
set_weight_attrs(weight_tensor, {"load_in_8bit": True})
yield qweight_name, weight_tensor
yield weight_name, weight_tensor
else:
yield weight_name, weight_tensor

Expand Down Expand Up @@ -950,9 +948,8 @@ def _parse_quant_state(param_name: str,
(f"{weight_name}.quant_state.bitsandbytes__fp4" \
in temp_state_dict):
quant_state = _parse_quant_state(weight_name, temp_state_dict)
weight_name = weight_name.replace(".weight", ".qweight")
quant_state_dict[weight_name] = quant_state
yield weight_name.replace(".weight", ".qweight"), weight_tensor
yield weight_name, weight_tensor
else:
yield weight_name, weight_tensor

Expand All @@ -967,7 +964,6 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,

if any(target_module in weight_name for target_module in
self.target_modules) and weight_name.endswith(".weight"):
weight_name = weight_name.replace(".weight", ".qweight")
# Without sharding
if any(
weight_name.startswith(module)
Expand Down Expand Up @@ -1093,7 +1089,7 @@ def _load_weights(self, model_config: ModelConfig,
# Some models, such as MiniCPM V2.5/2.6, contain both
# module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
# from being incorrectly identified as being present in
# 'vpm.encoder.layers.0.self_attn.qkv_proj.qweight'
# 'vpm.encoder.layers.0.self_attn.qkv_proj.weight'
if shard_pos > 0 and quant_param_name[shard_pos - 1] == ".":
shard_index = index
quant_param_name = quant_param_name.replace(
Expand Down

0 comments on commit 9c0a5ed

Please sign in to comment.