Complete QWenVL LoRA support
Signed-off-by: Jee Jee Li <[email protected]>
jeejeelee committed Oct 28, 2024
1 parent 6493ee4 commit 6462961
Showing 4 changed files with 104 additions and 52 deletions.
25 changes: 22 additions & 3 deletions vllm/lora/models.py
@@ -26,7 +26,8 @@
is_regex_target_modules,
parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.module_mapping import (ModelComposeMethod,
MultiModelKeys)
from vllm.model_executor.models.utils import PPMissingLayer
from vllm.utils import is_pin_memory_available

@@ -577,11 +578,29 @@ def _filter_unsupported_mm_module(self, module_name: str) -> bool:
language model. LoRA for other modules, such as the vision tower, will
be filtered out.
"""
if self.supports_mm:
module_mapping: MultiModelKeys = self.model.get_mm_mapping()

def _verify_decoupled_model():
"""
Suitable for MiniCPMV, InternVL, etc.
"""
prefix = module_name.split(".")[0]
module_mapping: MultiModelKeys = self.model.get_mm_mapping()
return (prefix in module_mapping.connector
or prefix in module_mapping.tower_model)

def _verify_coupled_model():
"""
Suitable for QWenVL, GLM4V, etc.
"""
prefix_lst = module_mapping.connector + module_mapping.tower_model
return any(
[module_name.startswith(prefix) for prefix in prefix_lst])

if self.supports_mm:
if module_mapping.compose_type == ModelComposeMethod.Decoupled:
return _verify_decoupled_model()
else:
return _verify_coupled_model()
return False

def _register_packed_modules(self, module_full_name: str) -> None:
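The two verification helpers differ only in how a LoRA target's module name is matched against the multimodal prefixes: decoupled models (MiniCPMV, InternVL) expose the tower and connector at the top level, so comparing the first dotted component is enough, while coupled models (QWenVL, GLM4V) nest them under the language-model root and need a full prefix match. A minimal, self-contained sketch of that distinction (the `MMKeys` dataclass and the sample module names are illustrative stand-ins, not vLLM's actual classes):

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class MMKeys:
    # Illustrative stand-in for vLLM's MultiModelKeys.
    connector: List[str] = field(default_factory=list)
    tower_model: List[str] = field(default_factory=list)


def is_unsupported_mm_module(module_name: str, keys: MMKeys,
                             decoupled: bool) -> bool:
    """Return True if a LoRA target lives in the vision tower or connector."""
    prefixes = keys.connector + keys.tower_model
    if decoupled:
        # Decoupled layout: vision parts sit at the top level, so the first
        # dotted component identifies them ("vpm.encoder..." -> "vpm").
        return module_name.split(".")[0] in prefixes
    # Coupled layout: vision parts are nested ("transformer.visual...."),
    # so a startswith prefix match is required.
    return any(module_name.startswith(prefix) for prefix in prefixes)


# Coupled example: QWen-VL keeps its vision modules under "transformer.visual".
qwen_keys = MMKeys(connector=["transformer.visual.attn_pool"],
                   tower_model=["transformer.visual.transformer"])
print(is_unsupported_mm_module(
    "transformer.visual.transformer.resblocks.0.attn", qwen_keys, False))  # True
print(is_unsupported_mm_module(
    "transformer.h.0.attn.c_attn", qwen_keys, False))                      # False
```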
11 changes: 7 additions & 4 deletions vllm/model_executor/models/minicpmv.py
@@ -48,7 +48,8 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.minicpm import MiniCPMModel
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.module_mapping import (ModelComposeMethod,
MultiModelKeys)
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.models.utils import LLMWrapper
from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -635,9 +636,11 @@ def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(language_model="llm",
connector="resampler",
tower_model="vpm")
return MultiModelKeys.from_string_field(
language_model="llm",
connector="resampler",
tower_model="vpm",
compose_type=ModelComposeMethod.Decoupled)

def init_llm(
self,
42 changes: 42 additions & 0 deletions vllm/model_executor/models/module_mapping.py
@@ -2,9 +2,46 @@
# https://github.com/modelscope/ms-swift/blob/v2.4.2/swift/utils/module_mapping.py

from dataclasses import dataclass, field
from enum import IntEnum
from typing import List, Union


class ModelComposeMethod(IntEnum):
"""
`ModelComposeMethod` distinguishes between two architectural patterns in
multi-modal models, focusing on how vision model, language model, and
projector are implemented:
1. Decoupled Architecture (like mllama, InternVL, MiniCPMV): the vision
model, language model, and projector are completely independent
implementations, each with its own layers, for example:
```
InternVLChatModel
├── vision_model (visual encoder)
│ ├── embeddings
│ └── encoder
├── language_model (language model)
│ ├── tok_embeddings
│ └── layers
└── mlp1 (projector)
```
2. Coupled Architecture (like QWenVL, GLM4V): the visual encoder and
projector are integrated as sub-modules of the language model backbone,
sharing its architectural patterns, for example:
```
QWenVL
└── transformer
├── wte
├── h (language model)
├── ln_f
└── visual (visual encoder)
├── conv1
├── transformer
└── attn_pool (projector)
```
"""
Decoupled = 0
Coupled = 1


@dataclass
class ModelKeys:
model_type: str = None
@@ -41,6 +78,8 @@ class ModelKeys:

output: str = None

compose_type: ModelComposeMethod = None


@dataclass
class MultiModelKeys(ModelKeys):
@@ -55,7 +94,9 @@ def from_string_field(language_model: Union[str, List[str]] = None,
connector: Union[str, List[str]] = None,
tower_model: Union[str, List[str]] = None,
generator: Union[str, List[str]] = None,
compose_type: ModelComposeMethod = None,
**kwargs) -> 'MultiModelKeys':
assert compose_type is not None, "compose_type is not allowed to be None"

def to_list(value):
if value is None:
@@ -66,4 +107,5 @@ def to_list(value):
connector=to_list(connector),
tower_model=to_list(tower_model),
generator=to_list(generator),
compose_type=compose_type,
**kwargs)
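Because `compose_type` is now mandatory, every multimodal model that supports LoRA must declare its layout when building its key mapping. A short sketch of the two constructions used elsewhere in this commit (assumes a vLLM build containing this patch; the prefixes shown are the ones MiniCPMV and QWenVL report):

```python
# Sketch only: requires a vLLM build that includes this patch.
from vllm.model_executor.models.module_mapping import (ModelComposeMethod,
                                                       MultiModelKeys)

# Decoupled layout (MiniCPMV): tower and connector live at the top level.
minicpmv_keys = MultiModelKeys.from_string_field(
    language_model="llm",
    connector="resampler",
    tower_model="vpm",
    compose_type=ModelComposeMethod.Decoupled)

# Coupled layout (QWenVL): vision modules are nested under the LLM root.
qwenvl_keys = MultiModelKeys.from_string_field(
    language_model="transformer.h",
    connector="transformer.visual.attn_pool",
    tower_model="transformer.visual.transformer",
    compose_type=ModelComposeMethod.Coupled)

print(qwenvl_keys.tower_model)  # ['transformer.visual.transformer']
```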
78 changes: 33 additions & 45 deletions vllm/model_executor/models/qwen.py
@@ -30,6 +30,7 @@
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -39,7 +40,8 @@
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.module_mapping import (ModelComposeMethod,
MultiModelKeys)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
@@ -123,8 +125,8 @@ def __init__(
# Strided linear layer.
assert self._qkv_same_embed_dim, \
'Visual Attention implementation only supports self-attention'
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)
self.in_proj = ReplicatedLinear(embed_dim, 3 * embed_dim)
self.out_proj = ReplicatedLinear(embed_dim, embed_dim)
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)

def forward(
@@ -134,7 +136,7 @@ def forward(
) -> torch.Tensor:
# query/key/value: [sq, b, h]
sq, b, _ = x.size()
mixed_x_layer = self.in_proj(x)
mixed_x_layer, _ = self.in_proj(x)

# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + \
@@ -183,7 +185,7 @@ def forward(
(self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)

output = self.out_proj(context_layer)
output, _ = self.out_proj(context_layer)

return output
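Swapping `nn.Linear` for `ReplicatedLinear` is presumably what lets the visual attention projections be picked up by vLLM's LoRA layer wrappers, and it also changes the call convention: vLLM's linear layers return an `(output, output_bias)` tuple rather than a bare tensor, which is why the call sites now unpack `mixed_x_layer, _ = self.in_proj(x)`. A torch-only sketch of that convention (`TupleLinear` is an illustrative stand-in, not vLLM's implementation):

```python
import torch
import torch.nn as nn


class TupleLinear(nn.Linear):
    """Illustrative stand-in mimicking how vLLM's ReplicatedLinear-style
    layers return (output, output_bias) instead of a bare tensor."""

    def forward(self, x: torch.Tensor):
        return super().forward(x), None


layer = TupleLinear(16, 48)
x = torch.randn(4, 2, 16)   # [sq, b, h], matching the shapes used above
out, _ = layer(x)           # call sites must unpack the tuple
print(out.shape)            # torch.Size([4, 2, 48])
```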

@@ -992,29 +994,17 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

class QWenLLM(QWenBaseModel):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"c_attn": ["c_attn"],
"gate_up_proj": [
"gate_proj",
"up_proj",
"w2",
"w1",
],
}
# LoRA specific attributes
supported_lora_modules = [
# vision encoder
"fc1",
"fc2",
"out_proj",
# language model
"qkv_proj", # same name with vision encoder
"o_proj",
"c_attn",
"gate_up_proj",
"down_proj",
# resampler
"kv_proj",
"c_proj",
]

embedding_modules = {}
@@ -1023,27 +1013,21 @@ class QWenLLM(QWenBaseModel):

class QWenVL(QWenBaseModel):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"c_attn": ["c_attn"],
"gate_up_proj": [
"gate_proj",
"up_proj",
"w2",
"w1",
],
}
# LoRA specific attributes
supported_lora_modules = [
# vision encoder
"fc1",
"fc2",
"out_proj",
# language model
"qkv_proj", # same name with vision encoder
"o_proj",
"c_attn",
"gate_up_proj",
"down_proj",
"c_proj",
# visual module
"out_proj",
"in_proj",
"c_fc",
# resampler
"kv_proj",
]
@@ -1055,9 +1039,11 @@ def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(language_model="llm",
connector="resampler",
tower_model="vpm")
return MultiModelKeys.from_string_field(
language_model="transformer.h",
connector="transformer.visual.attn_pool",
tower_model="transformer.visual.transformer",
compose_type=ModelComposeMethod.Coupled)


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
@@ -1085,9 +1071,11 @@ def __new__(
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
if multimodal_config is None:
return QWenLLM(config, multimodal_config, cache_config,
quant_config)
else:
# Initialize VL
if hasattr(config, "visual"):
return QWenVL(config, multimodal_config, cache_config,
quant_config)
quant_config, lora_config)
# Initialize LLM
else:
return QWenLLM(config, multimodal_config, cache_config,
quant_config, lora_config)
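With the visual and resampler projections listed in `supported_lora_modules` and `__new__` now dispatching on `hasattr(config, "visual")`, a Qwen-VL LoRA adapter can be loaded through vLLM's standard offline API. A minimal sketch, assuming a vLLM build that contains this commit; the checkpoint name, adapter path, and rank are placeholders:

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="Qwen/Qwen-VL-Chat",   # placeholder checkpoint
          trust_remote_code=True,
          enable_lora=True,
          max_lora_rank=64)            # match the adapter's rank

outputs = llm.generate(
    "Describe what a resampler does in a vision-language model.",
    SamplingParams(temperature=0.0, max_tokens=64),
    lora_request=LoRARequest("qwenvl-lora", 1, "/path/to/lora_adapter"),
)
print(outputs[0].outputs[0].text)
```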
