[Misc]Reduce BNB static variable #9987

Merged 2 commits on Nov 4, 2024
Changes from 1 commit
40 changes: 20 additions & 20 deletions vllm/model_executor/model_loader/loader.py
@@ -28,7 +28,8 @@
get_tensor_model_parallel_world_size)
from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.linear import (ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.model_loader.tensorizer import (
@@ -755,6 +756,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)

# Save the module names without sharding.
self.unsharded_weights_modules: List[str] = []
# Save the module names that are sharded by column.
self.column_sharded_weights_modules: List[str] = []
# we don't need to quantize the whole model, only the target modules
# that are specified in the adapter config file. If the adapter config
# file is not provided, we will quantize the default modules.
@@ -772,8 +777,6 @@ def __init__(self, load_config: LoadConfig):
with open(config_file_path, "r") as f:
config = json.load(f)
self.target_modules = config["target_modules"]
# Save the module names without sharding.
self.unsharded_weights_modules: List[str] = []

def _get_config_file(self, qlora_adapter: str) -> str:
is_local = os.path.isdir(qlora_adapter)
@@ -999,9 +1002,9 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
for module in self.unsharded_weights_modules):
weight_sub_tensor = weight_tensor
# Shard by column
elif any(module in weight_name
for module in self.column_parallel_weights_modules):

elif any(
weight_name.startswith(module)
for module in self.column_sharded_weights_modules):
total_size = weight_tensor.size(-1)
start_index = total_size // tp_size * tp_rank
end_index = total_size // tp_size * (tp_rank + 1)
@@ -1056,20 +1059,17 @@ def _load_weights(self, model_config: ModelConfig,
else:
self.target_modules = self.default_target_modules

if hasattr(model, 'column_parallel_weights_modules'):
self.column_parallel_weights_modules = \
model.column_parallel_weights_modules
else:
self.column_parallel_weights_modules = []
# Some modules like `ReplicatedLinear` should not have their weights
# sharded. The reason for implementing it this way is to avoid new
# static variable in the model implementation.
# TODO: Can we reduce the static variables needed for BNB based on
# model information?
self.unsharded_weights_modules = [
name for name, module in model.named_modules()
if isinstance(module, (ReplicatedLinear, ))
]
for name, module in model.named_modules():
# Some modules like `ReplicatedLinear` should not have their weights
# sharded. The reason for implementing it this way is to avoid new
# static variable in the model implementation.
if isinstance(module, (ReplicatedLinear, )):
self.unsharded_weights_modules.append(name)
# In TP, these weights are partitioned along the column
# dimension (dim=-1)
elif isinstance(module, (RowParallelLinear, )):
self.column_sharded_weights_modules.append(name)

self.model_type = type(model).__name__

logger.info("Loading weights with BitsAndBytes quantization. "
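The loader change above replaces per-model static lists with type-based detection: `ReplicatedLinear` modules are loaded unsharded, while `RowParallelLinear` modules are sliced along the column dimension (dim=-1) under tensor parallelism. Below is a minimal, standalone sketch of that idea together with the column-slicing arithmetic from `_unquantized_generator`. The `ReplicatedLinear`/`RowParallelLinear` classes and `TinyModel` here are simplified stand-ins for illustration, not vLLM's actual layers.

    # Minimal sketch of the type-based detection introduced in this PR.
    # Assumptions: the two Linear subclasses below are stand-ins for vLLM's
    # layers, and TinyModel is a made-up example model.
    from typing import List, Tuple

    import torch
    import torch.nn as nn


    class ReplicatedLinear(nn.Linear):
        """Stand-in: weights are fully replicated across ranks, never sharded."""


    class RowParallelLinear(nn.Linear):
        """Stand-in: in TP, weights are partitioned along the column dim (dim=-1)."""


    class TinyModel(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.gate_up_proj = nn.Linear(8, 16)       # plain Linear: lands in neither list
            self.down_proj = RowParallelLinear(16, 8)  # column-sharded under TP
            self.router = ReplicatedLinear(8, 4)       # loaded unsharded


    def classify_modules(model: nn.Module) -> Tuple[List[str], List[str]]:
        # Mirrors the loop added to _load_weights: walk named_modules() once
        # and bucket module names by type instead of reading static lists.
        unsharded: List[str] = []
        column_sharded: List[str] = []
        for name, module in model.named_modules():
            if isinstance(module, ReplicatedLinear):
                unsharded.append(name)
            elif isinstance(module, RowParallelLinear):
                column_sharded.append(name)
        return unsharded, column_sharded


    def column_shard(weight: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
        # Same slicing arithmetic as the "Shard by column" branch in
        # _unquantized_generator: split dim=-1 into tp_size chunks, keep one.
        total_size = weight.size(-1)
        start = total_size // tp_size * tp_rank
        end = total_size // tp_size * (tp_rank + 1)
        return weight[..., start:end]


    if __name__ == "__main__":
        model = TinyModel()
        unsharded, column_sharded = classify_modules(model)
        print(unsharded)       # ['router']
        print(column_sharded)  # ['down_proj']
        # down_proj.weight has shape (8, 16); rank 0 of 2 keeps columns 0..7.
        print(column_shard(model.down_proj.weight.data, tp_rank=0, tp_size=2).shape)

With the detection done in the loader, the per-model `column_parallel_weights_modules` lists become redundant, which is exactly what the remaining diffs below delete.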
2 changes: 0 additions & 2 deletions vllm/model_executor/models/falcon.py
@@ -401,8 +401,6 @@ class FalconForCausalLM(nn.Module, SupportsPP):
".dense_h_to_4h.",
".dense_4h_to_h.",
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".dense_4h_to_h.", ".dense."]

def __init__(
self,
3 changes: 1 addition & 2 deletions vllm/model_executor/models/gemma.py
@@ -361,8 +361,7 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
".v_proj.",
".o_proj.",
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".down_proj.", ".o_proj."]

bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
3 changes: 1 addition & 2 deletions vllm/model_executor/models/gemma2.py
@@ -390,8 +390,7 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
".v_proj.",
".o_proj.",
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".down_proj.", ".o_proj."]

bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
3 changes: 1 addition & 2 deletions vllm/model_executor/models/llama.py
@@ -464,8 +464,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
".v_proj.",
".o_proj.",
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".down_proj.", ".o_proj."]

bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
10 changes: 2 additions & 8 deletions vllm/model_executor/models/minicpmv.py
@@ -854,10 +854,7 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
# resampler
".kv_proj.",
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [
".down_proj.", ".o_proj.", ".self_attn.out_proj.", ".fc2."
]

bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
@@ -1008,10 +1005,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
# resampler
".kv_proj.",
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [
".down_proj.", ".o_proj.", ".self_attn.out_proj.", ".fc2."
]

bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
3 changes: 1 addition & 2 deletions vllm/model_executor/models/mllama.py
@@ -1062,8 +1062,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
# so we can't add a dot in front of it.
"multi_modal_projector."
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".down_proj.", ".o_proj.", ".fc2."]

bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
2 changes: 0 additions & 2 deletions vllm/model_executor/models/opt.py
@@ -343,8 +343,6 @@ class OPTForCausalLM(nn.Module, SupportsPP):
default_bitsandbytes_target_modules = [
".q_proj.", ".k_proj.", ".v_proj.", ".out_proj.", ".fc1.", ".fc2."
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".out_proj.", ".fc2."]

def __init__(
self,
2 changes: 0 additions & 2 deletions vllm/model_executor/models/phi.py
@@ -274,8 +274,6 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
default_bitsandbytes_target_modules = [
".q_proj.", ".k_proj.", ".v_proj.", ".fc1.", ".fc2.", ".dense."
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".fc2.", ".dense."]

embedding_modules = {}
embedding_padding_modules = []
2 changes: 0 additions & 2 deletions vllm/model_executor/models/qwen2.py
@@ -396,8 +396,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
".o_proj.",
]

# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".down_proj.", ".o_proj."]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),