Skip to content

Commit

Permalink
[Model][LoRA]LoRA support added for Qwen (vllm-project#9622)
Browse files Browse the repository at this point in the history
Signed-off-by: Jee Jee Li <[email protected]>
  • Loading branch information
jeejeelee authored and Linkun Chen committed Nov 4, 2024
1 parent b6666cc commit 33fdcae
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 14 deletions.
6 changes: 3 additions & 3 deletions vllm/lora/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,10 +578,10 @@ def _filter_unsupported_mm_module(self, module_name: str) -> bool:
be filtered out.
"""
if self.supports_mm:
prefix = module_name.split(".")[0]
module_mapping: MultiModelKeys = self.model.get_mm_mapping()
return (prefix in module_mapping.connector
or prefix in module_mapping.tower_model)
prefix_lst = module_mapping.connector + module_mapping.tower_model
return any(
[module_name.startswith(prefix) for prefix in prefix_lst])
return False

def _register_packed_modules(self, module_full_name: str) -> None:
Expand Down
109 changes: 98 additions & 11 deletions vllm/model_executor/models/qwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
Expand All @@ -30,6 +30,7 @@
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
Expand All @@ -39,14 +40,15 @@
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.utils import is_list_of

from .interfaces import SupportsMultiModal, SupportsPP
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)

Expand Down Expand Up @@ -122,8 +124,8 @@ def __init__(
# Strided linear layer.
assert self._qkv_same_embed_dim, \
'Visual Attention implementation only supports self-attention'
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)
self.in_proj = ReplicatedLinear(embed_dim, 3 * embed_dim)
self.out_proj = ReplicatedLinear(embed_dim, embed_dim)
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)

def forward(
Expand All @@ -133,7 +135,7 @@ def forward(
) -> torch.Tensor:
# query/key/value: [sq, b, h]
sq, b, _ = x.size()
mixed_x_layer = self.in_proj(x)
mixed_x_layer, _ = self.in_proj(x)

# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + \
Expand Down Expand Up @@ -182,7 +184,7 @@ def forward(
(self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)

output = self.out_proj(context_layer)
output, _ = self.out_proj(context_layer)

return output

Expand Down Expand Up @@ -860,18 +862,15 @@ def dummy_data_for_qwen(
return seq_data, mm_data


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
class QWenLMHeadModel(nn.Module, SupportsMultiModal, SupportsPP):
class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):

def __init__(
self,
config: PretrainedConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()
self.config = config
Expand Down Expand Up @@ -990,3 +989,91 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)


class QWenLLM(QWenBaseModel):
packed_modules_mapping = {
"c_attn": ["c_attn"],
"gate_up_proj": [
"w2",
"w1",
],
}
# LoRA specific attributes
supported_lora_modules = [
"c_attn",
"gate_up_proj",
"c_proj",
]

embedding_modules = {}
embedding_padding_modules = []


class QWenVL(QWenBaseModel):
packed_modules_mapping = {
"c_attn": ["c_attn"],
"gate_up_proj": [
"w2",
"w1",
],
}
# LoRA specific attributes
supported_lora_modules = [
"c_attn",
"gate_up_proj",
"c_proj",
# visual module
"out_proj",
"in_proj",
"c_fc",
# resampler
"kv_proj",
]

embedding_modules = {}
embedding_padding_modules = []

def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="transformer.h",
connector="transformer.visual.attn_pool",
tower_model="transformer.visual.transformer")


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
class QWenLMHeadModel(QWenBaseModel):
"""
QWenLMHeadModel is not only applicable to LLM but also to VL, which is not
conducive to the current integration logic of LoRA in vLLM. Therefore, it
is necessary to separate them.
"""
# Ensure that the LoRA support check passes when the class is not
# initialized, but set all these attributes to empty.
packed_modules_mapping = {}
supported_lora_modules = []
embedding_modules = {}
embedding_padding_modules = []

def __new__(
cls,
config: PretrainedConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
# Initialize VL
if hasattr(config, "visual"):
return QWenVL(config, multimodal_config, cache_config,
quant_config, lora_config)
# Initialize LLM
else:
return QWenLLM(config, multimodal_config, cache_config,
quant_config, lora_config)

0 comments on commit 33fdcae

Please sign in to comment.