[Bugfix] Fix BNB loader target_modules (vllm-project#10720)
Signed-off-by: Jee Jee Li <[email protected]>
jeejeelee authored Dec 5, 2024
1 parent aa39a8e commit 1f958a7
Showing 1 changed file with 6 additions and 58 deletions.
vllm/model_executor/model_loader/loader.py: 6 additions & 58 deletions
@@ -6,7 +6,6 @@
 import glob
 import inspect
 import itertools
-import json
 import math
 import os
 import warnings
@@ -18,7 +17,7 @@
 import huggingface_hub
 import numpy as np
 import torch
-from huggingface_hub import HfApi, hf_hub_download
+from huggingface_hub import HfApi
 from torch import nn
 from transformers import AutoModelForCausalLM
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
@@ -704,51 +703,9 @@ def __init__(self, load_config: LoadConfig):
         self.unsharded_weights_modules: List[str] = []
         # Save the module names that are sharded by column.
         self.column_sharded_weights_modules: List[str] = []
-        # we don't need to quantize the whole model, only the target modules
-        # that are specified in the adapter config file. If the adapter config
-        # file is not provided, we will quantize the default modules.
-        if (not load_config.model_loader_extra_config
-                or "qlora_adapter_name_or_path"
-                not in load_config.model_loader_extra_config):
-            self.target_modules = []
-            return
-
-        qlora_adapter = load_config.model_loader_extra_config[
-            "qlora_adapter_name_or_path"]
-
-        config_file_path = self._get_config_file(qlora_adapter)
-
-        with open(config_file_path) as f:
-            config = json.load(f)
-            self.target_modules = config["target_modules"]
-            # TODO: target_modules could be either a list or a regex string.
-            # We need to handle both cases.
-            assert isinstance(self.target_modules,
-                              list), "Unsupported target_modules: "
-            f"{self.target_modules}"
-
-    def _get_config_file(self, qlora_adapter: str) -> str:
-        is_local = os.path.isdir(qlora_adapter)
-        config_file_path = None
-        if is_local:
-            for file in self.possible_config_file_names:
-                config_file_path = os.path.join(qlora_adapter, file)
-                if os.path.exists(config_file_path):
-                    break
-        else:
-            hf_api = HfApi()
-            repo_files = hf_api.list_repo_files(repo_id=qlora_adapter)
-            for file in self.possible_config_file_names:
-                if file in repo_files:
-                    config_file_path = hf_hub_download(repo_id=qlora_adapter,
-                                                       filename=file)
-                    break
-
-        if not config_file_path:
-            raise ValueError(
-                f"Cannot find adapter config file in {qlora_adapter}")
-
-        return config_file_path
+        # Store all module names (from transformers) that support
+        # BNB quantization.
+        self.target_modules: List[str] = []
 
     def _get_weight_files(
             self,
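
What this hunk changes: target_modules is no longer seeded from a QLoRA adapter's config file; it starts empty and is filled from the model itself by _get_bnb_target_modules (next hunk). Below is a minimal sketch of the behavior being deleted, assuming a PEFT-style adapter_config.json (the shape the removed possible_config_file_names lookup suggests); the path and file name are hypothetical.

```python
# Sketch of the removed lookup: target_modules used to be parsed out of a
# PEFT-style adapter config file. The path below is a made-up example.
import json

with open("my-qlora-adapter/adapter_config.json") as f:
    config = json.load(f)

# In PEFT configs this value may be a list of module names OR a regex
# string, the unhandled case the removed TODO pointed at.
target_modules = config["target_modules"]
```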
@@ -1030,25 +987,16 @@ def _get_bnb_target_modules(self, model: nn.Module) -> None:
                 inverse_stacked_mapping[packed] = []
             inverse_stacked_mapping[packed].insert(idx, orig)
 
-        linear_module_lst = []
         for name, module in model.named_modules():
             if isinstance(module, (LinearBase, )):
                 last_name = name.split(".")[-1]
                 if sub_modules := inverse_stacked_mapping.get(last_name, []):
                     # Map vllm's names to transformers' names.
                     for sub_name in sub_modules:
-                        linear_module_lst.append(
+                        self.target_modules.append(
                             name.replace(last_name, sub_name))
                 else:
-                    linear_module_lst.append(name)
-        if self.target_modules:
-            # Update self.target_modules
-            self.target_modules = [
-                qual_name for qual_name in linear_module_lst
-                if any(t in qual_name for t in self.target_modules)
-            ]
-        else:
-            self.target_modules = linear_module_lst
+                    self.target_modules.append(name)
         assert (self.target_modules
                 ), "vllm currently does not support BNB quantization for"
         f" {type(model).__name__}"

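For reference, a hedged usage sketch of the load path this commit fixes. The flags follow vLLM's documented bitsandbytes support around this release; the model ID is only an example, and after this change no qlora_adapter_name_or_path extra config is needed to determine target modules.

```python
# Loading a pre-quantized bitsandbytes checkpoint through vLLM's BNB loader.
# Model ID and flags are examples based on vLLM's BNB documentation of this
# era, not taken from the commit itself.
from vllm import LLM

llm = LLM(model="unsloth/tinyllama-bnb-4bit",
          quantization="bitsandbytes",
          load_format="bitsandbytes")
```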