Skip to content

Commit

Permalink
[Bugfix] Fix MiniCPMV and Mllama BNB bug (vllm-project#9917)
Browse files Browse the repository at this point in the history
Signed-off-by: Jee Jee Li <[email protected]>
Signed-off-by: Richard Liu <[email protected]>
  • Loading branch information
jeejeelee authored and richardsliu committed Nov 4, 2024
1 parent 122c821 commit 2623641
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 65 deletions.
49 changes: 28 additions & 21 deletions vllm/model_executor/layers/resampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from torch.nn.init import trunc_normal_

from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig

DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)

Expand Down Expand Up @@ -154,15 +155,15 @@ class BaseResampler(nn.Module):
A tensor with the shape of (grid_size**2, embed_dim)
"""

def __init__(
self,
num_queries: int,
embed_dim: int,
num_heads: int,
kv_dim: Optional[int] = None,
norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
do_post_projection: bool = True,
) -> None:
def __init__(self,
num_queries: int,
embed_dim: int,
num_heads: int,
kv_dim: Optional[int] = None,
norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
do_post_projection: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
super().__init__()

self.num_queries = num_queries
Expand All @@ -172,7 +173,11 @@ def __init__(
self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
trunc_normal_(self.query, std=0.02)
if kv_dim is not None and kv_dim != embed_dim:
self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False)
self.kv_proj = ReplicatedLinear(kv_dim,
embed_dim,
bias=False,
quant_config=quant_config,
prefix=prefix)
else:
# Maintain the same return value with ReplicatedLinear.forward
self.kv_proj = lambda *args, **kwargs: ( # type: ignore # noqa
Expand Down Expand Up @@ -209,22 +214,24 @@ class Resampler2(BaseResampler):
present in minicpmv2.0, but not qwen-vl.
"""

def __init__(
self,
grid_size: int,
embed_dim: int,
num_heads: int,
kv_dim: Optional[int] = None,
norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
adaptive: bool = False,
do_post_projection: bool = True,
) -> None:
def __init__(self,
grid_size: int,
embed_dim: int,
num_heads: int,
kv_dim: Optional[int] = None,
norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
adaptive: bool = False,
do_post_projection: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
super().__init__(grid_size**2,
embed_dim,
num_heads,
kv_dim,
norm_layer,
do_post_projection=do_post_projection)
do_post_projection=do_post_projection,
quant_config=quant_config,
prefix=prefix)

self.adaptive = adaptive
pos_embed_arr = get_2d_sincos_pos_embed(embed_dim,
Expand Down
34 changes: 28 additions & 6 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
get_tensor_model_parallel_world_size)
from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.model_loader.tensorizer import (
Expand Down Expand Up @@ -771,6 +772,8 @@ def __init__(self, load_config: LoadConfig):
with open(config_file_path, "r") as f:
config = json.load(f)
self.target_modules = config["target_modules"]
# Save the module names without sharding.
self.unsharded_weights_modules: List[str] = []

def _get_config_file(self, qlora_adapter: str) -> str:
is_local = os.path.isdir(qlora_adapter)
Expand Down Expand Up @@ -990,16 +993,21 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
if any(target_module in weight_name for target_module in
self.target_modules) and weight_name.endswith(".weight"):
weight_name = weight_name.replace(".weight", ".qweight")

if any(module in weight_name
for module in self.column_parallel_weights_modules):
# Without sharding
if any(
weight_name.startswith(module)
for module in self.unsharded_weights_modules):
weight_sub_tensor = weight_tensor
# Shard by column
elif any(module in weight_name
for module in self.column_parallel_weights_modules):

total_size = weight_tensor.size(-1)
start_index = total_size // tp_size * tp_rank
end_index = total_size // tp_size * (tp_rank + 1)
weight_sub_tensor = weight_tensor[...,
start_index:end_index]

# Shard by row
else:
total_size = weight_tensor.size(0)
start_index = total_size // tp_size * tp_rank
Expand Down Expand Up @@ -1053,7 +1061,15 @@ def _load_weights(self, model_config: ModelConfig,
model.column_parallel_weights_modules
else:
self.column_parallel_weights_modules = []

# Some modules like `ReplicatedLinear` should not have their weights
# sharded. The reason for implementing it this way is to avoid new
# static variable in the model implementation.
# TODO: Can we reduce the static variables needed for BNB based on
# model information?
self.unsharded_weights_modules = [
name for name, module in model.named_modules()
if isinstance(module, (ReplicatedLinear, ))
]
self.model_type = type(model).__name__

logger.info("Loading weights with BitsAndBytes quantization. "
Expand Down Expand Up @@ -1100,7 +1116,13 @@ def _load_weights(self, model_config: ModelConfig,
for shard_name, (
weight_name, index
) in model.bitsandbytes_stacked_params_mapping.items():
if shard_name in quant_param_name:

shard_pos = quant_param_name.find(shard_name)
# Some models, such as MiniCPM V2.5/2.6, contain both
# module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
# from being incorrectly identified as being present in
# 'vpm.encoder.layers.0.self_attn.qkv_proj.qweight
if shard_pos > 0 and quant_param_name[shard_pos - 1] == ".":
shard_index = index
quant_param_name = quant_param_name.replace(
shard_name, weight_name)
Expand Down
120 changes: 83 additions & 37 deletions vllm/model_executor/models/minicpmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,16 +131,22 @@ class MiniCPMVImageEmbeddingInputs(TypedDict):

class Resampler2_5(BaseResampler):

def __init__(
self,
num_queries: int,
embed_dim: int,
num_heads: int,
kv_dim: Optional[int] = None,
norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
max_size: Tuple[int, int] = (70, 70),
) -> None:
super().__init__(num_queries, embed_dim, num_heads, kv_dim, norm_layer)
def __init__(self,
num_queries: int,
embed_dim: int,
num_heads: int,
kv_dim: Optional[int] = None,
norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
max_size: Tuple[int, int] = (70, 70),
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
super().__init__(num_queries,
embed_dim,
num_heads,
kv_dim,
norm_layer,
quant_config=quant_config,
prefix=prefix)

self.max_size = max_size
self._set_2d_pos_cache(self.max_size)
Expand Down Expand Up @@ -404,7 +410,10 @@ def __init__(
self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
self.vpm.embeddings.embed_dim)
self.embed_dim = self.config.hidden_size
self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
self.resampler = self.init_resampler(self.embed_dim,
self.vision_dim,
quant_config=quant_config,
prefix="resampler")
self.resampler.to(device="cuda", dtype=param_dtype)
# TODO: why is there _KEYS_TO_MODIFY_MAPPING? lm_head should be in llm
self.lm_head = ParallelLMHead(config.vocab_size,
Expand Down Expand Up @@ -666,7 +675,11 @@ def init_vision_module(
) -> nn.Module:
raise NotImplementedError

def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
def init_resampler(self,
embed_dim: int,
vision_dim: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> nn.Module:
raise NotImplementedError

def get_vision_embedding(
Expand Down Expand Up @@ -743,16 +756,21 @@ def init_vision_module(
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_tokens(input_ids)

def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
def init_resampler(self,
embed_dim: int,
vision_dim: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> nn.Module:
with set_default_torch_dtype(torch.float16):
resampler = Resampler2(
embed_dim=embed_dim,
num_heads=embed_dim // 128,
grid_size=int(math.sqrt(self.config.query_num)),
kv_dim=vision_dim,
adaptive=False,
do_post_projection=True,
)
resampler = Resampler2(embed_dim=embed_dim,
num_heads=embed_dim // 128,
grid_size=int(
math.sqrt(self.config.query_num)),
kv_dim=vision_dim,
adaptive=False,
do_post_projection=True,
quant_config=quant_config,
prefix=prefix)

return resampler

Expand Down Expand Up @@ -825,9 +843,21 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
".k_proj.",
".v_proj.",
".o_proj.",
# vision encoder
".fc1.",
".fc2.",
# Currently, vllm does not support BNB quantization for the `out_proj`
# of the resampler, so it's necessary to distinguish between the
# vision encoder and the resampler's out_proj. The same applies to
# MiniCPMV2_6.
".self_attn.out_proj.", # vision encoder out_proj
# resampler
".kv_proj.",
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".down_proj.", ".o_proj."]
column_parallel_weights_modules = [
".down_proj.", ".o_proj.", ".self_attn.out_proj.", ".fc2."
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
Expand Down Expand Up @@ -877,14 +907,18 @@ def init_vision_module(
model.encoder.layers = model.encoder.layers[:-1]
return model

def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
def init_resampler(self,
embed_dim: int,
vision_dim: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> nn.Module:
with set_default_torch_dtype(torch.float16):
resampler = Resampler2_5(
num_queries=self.config.query_num,
embed_dim=embed_dim,
num_heads=embed_dim // 128,
kv_dim=vision_dim,
)
resampler = Resampler2_5(num_queries=self.config.query_num,
embed_dim=embed_dim,
num_heads=embed_dim // 128,
kv_dim=vision_dim,
quant_config=quant_config,
prefix=prefix)
return resampler

def get_vision_embedding(
Expand Down Expand Up @@ -967,9 +1001,17 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
".k_proj.",
".v_proj.",
".o_proj.",
# vision encoder
".fc1.",
".fc2.",
".self_attn.out_proj.",
# resampler
".kv_proj.",
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".down_proj.", ".o_proj."]
column_parallel_weights_modules = [
".down_proj.", ".o_proj.", ".self_attn.out_proj.", ".fc2."
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
Expand Down Expand Up @@ -1019,15 +1061,19 @@ def init_vision_module(
model.encoder.layers = model.encoder.layers[:-1]
return model

def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
def init_resampler(self,
embed_dim: int,
vision_dim: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> nn.Module:
with set_default_torch_dtype(torch.float16):
# The resampler in 2.6 remains consistent with the one in 2.5.
resampler = Resampler2_5(
num_queries=self.config.query_num,
embed_dim=embed_dim,
num_heads=embed_dim // 128,
kv_dim=vision_dim,
)
resampler = Resampler2_5(num_queries=self.config.query_num,
embed_dim=embed_dim,
num_heads=embed_dim // 128,
kv_dim=vision_dim,
quant_config=quant_config,
prefix=prefix)
return resampler

def get_vision_embedding(
Expand Down
7 changes: 6 additions & 1 deletion vllm/model_executor/models/mllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -1056,9 +1056,14 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
".k_proj.",
".v_proj.",
".o_proj.",
".fc1.",
".fc2.",
# The `multi_modal_projector` is at the top level of the model,
# so we can't add a dot in front of it.
"multi_modal_projector."
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".down_proj.", ".o_proj."]
column_parallel_weights_modules = [".down_proj.", ".o_proj.", ".fc2."]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
Expand Down

0 comments on commit 2623641

Please sign in to comment.