From 1f958a7d52b24314e41c4bb56c51b1dce5405e05 Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Thu, 5 Dec 2024 13:20:26 +0800
Subject: [PATCH] [Bugfix] Fix BNB loader target_modules (#10720)

Signed-off-by: Jee Jee Li
---
 vllm/model_executor/model_loader/loader.py | 64 ++--
 1 file changed, 6 insertions(+), 58 deletions(-)

diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index b4921cc80797f..a0ea0e5fad3c2 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -6,7 +6,6 @@
 import glob
 import inspect
 import itertools
-import json
 import math
 import os
 import warnings
@@ -18,7 +17,7 @@
 import huggingface_hub
 import numpy as np
 import torch
-from huggingface_hub import HfApi, hf_hub_download
+from huggingface_hub import HfApi
 from torch import nn
 from transformers import AutoModelForCausalLM
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
@@ -704,51 +703,9 @@ def __init__(self, load_config: LoadConfig):
         self.unsharded_weights_modules: List[str] = []
         # Save the module names that are sharded by column.
         self.column_sharded_weights_modules: List[str] = []
-        # we don't need to quantize the whole model, only the target modules
-        # that are specified in the adapter config file. If the adapter config
-        # file is not provided, we will quantize the default modules.
-        if (not load_config.model_loader_extra_config
-                or "qlora_adapter_name_or_path"
-                not in load_config.model_loader_extra_config):
-            self.target_modules = []
-            return
-
-        qlora_adapter = load_config.model_loader_extra_config[
-            "qlora_adapter_name_or_path"]
-
-        config_file_path = self._get_config_file(qlora_adapter)
-
-        with open(config_file_path) as f:
-            config = json.load(f)
-            self.target_modules = config["target_modules"]
-            # TODO: target_modules could be either a list or a regex string.
-            # We need to handle both cases.
-            assert isinstance(self.target_modules,
-                              list), "Unsupported target_modules: "
-            f"{self.target_modules}"
-
-    def _get_config_file(self, qlora_adapter: str) -> str:
-        is_local = os.path.isdir(qlora_adapter)
-        config_file_path = None
-        if is_local:
-            for file in self.possible_config_file_names:
-                config_file_path = os.path.join(qlora_adapter, file)
-                if os.path.exists(config_file_path):
-                    break
-        else:
-            hf_api = HfApi()
-            repo_files = hf_api.list_repo_files(repo_id=qlora_adapter)
-            for file in self.possible_config_file_names:
-                if file in repo_files:
-                    config_file_path = hf_hub_download(repo_id=qlora_adapter,
-                                                       filename=file)
-                    break
-
-        if not config_file_path:
-            raise ValueError(
-                f"Cannot find adapter config file in {qlora_adapter}")
-
-        return config_file_path
+        # Store all module names (from transformers) that support
+        # BNB quantization.
+        self.target_modules: List[str] = []
 
     def _get_weight_files(
         self,
@@ -1030,25 +987,16 @@ def _get_bnb_target_modules(self, model: nn.Module) -> None:
                 inverse_stacked_mapping[packed] = []
             inverse_stacked_mapping[packed].insert(idx, orig)
 
-        linear_module_lst = []
         for name, module in model.named_modules():
             if isinstance(module, (LinearBase, )):
                 last_name = name.split(".")[-1]
                 if sub_modules := inverse_stacked_mapping.get(last_name, []):
                     # Map vllm's names to transformers' names.
                     for sub_name in sub_modules:
-                        linear_module_lst.append(
+                        self.target_modules.append(
                             name.replace(last_name, sub_name))
                 else:
-                    linear_module_lst.append(name)
-        if self.target_modules:
-            # Update self.target_modules
-            self.target_modules = [
-                qual_name for qual_name in linear_module_lst
-                if any(t in qual_name for t in self.target_modules)
-            ]
-        else:
-            self.target_modules = linear_module_lst
+                    self.target_modules.append(name)
         assert (self.target_modules
                 ), "vllm currently does not support BNB quantization for"
         f" {type(model).__name__}"
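
For reference, below is a standalone sketch of the name-mapping logic this patch converges on. It is not part of the patch: TinyModel, TinyAttention, and the stacked-params mapping values are illustrative stand-ins, and plain nn.Linear substitutes for vLLM's LinearBase.

    # Standalone sketch (hypothetical module names) of the renaming step in
    # _get_bnb_target_modules; nn.Linear stands in for vLLM's LinearBase.
    from typing import Dict, List, Tuple

    import torch.nn as nn


    class TinyAttention(nn.Module):

        def __init__(self) -> None:
            super().__init__()
            # vLLM fuses q/k/v into one linear layer; transformers keeps
            # three separate ones named q_proj, k_proj, and v_proj.
            self.qkv_proj = nn.Linear(64, 192)
            self.o_proj = nn.Linear(64, 64)


    class TinyModel(nn.Module):
        # orig sub-module name -> (packed vLLM name, shard index).
        bitsandbytes_stacked_params_mapping: Dict[str, Tuple[str, int]] = {
            "q_proj": ("qkv_proj", 0),
            "k_proj": ("qkv_proj", 1),
            "v_proj": ("qkv_proj", 2),
        }

        def __init__(self) -> None:
            super().__init__()
            self.attn = TinyAttention()


    def get_bnb_target_modules(model: nn.Module) -> List[str]:
        # Invert the mapping: packed name -> [orig names, ordered by index].
        inverse_stacked_mapping: Dict[str, List[str]] = {}
        mapping = model.bitsandbytes_stacked_params_mapping
        for orig, (packed, idx) in mapping.items():
            inverse_stacked_mapping.setdefault(packed, []).insert(idx, orig)

        # Walk every linear module; expand fused modules back to the
        # transformers-style names so they match the checkpoint's weights.
        target_modules: List[str] = []
        for name, module in model.named_modules():
            if isinstance(module, nn.Linear):
                last_name = name.split(".")[-1]
                if sub_modules := inverse_stacked_mapping.get(last_name, []):
                    for sub_name in sub_modules:
                        target_modules.append(
                            name.replace(last_name, sub_name))
                else:
                    target_modules.append(name)
        return target_modules


    if __name__ == "__main__":
        print(get_bnb_target_modules(TinyModel()))
        # ['attn.q_proj', 'attn.k_proj', 'attn.v_proj', 'attn.o_proj']

Running the sketch prints the fused qkv_proj expanded back into the three per-projection names, plus o_proj unchanged, so the collected target_modules line up with the transformers-style weight names found in the checkpoint.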