From 7a4df5f200f0943113dd2d9be49cbcae38ad10bb Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 29 Oct 2024 12:14:07 +0800 Subject: [PATCH] [Model][LoRA]LoRA support added for Qwen (#9622) Signed-off-by: Jee Jee Li --- vllm/lora/models.py | 6 +- vllm/model_executor/models/qwen.py | 109 ++++++++++++++++++++++++++--- 2 files changed, 101 insertions(+), 14 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index aaadca9a4d16d..d0279f273db7a 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -578,10 +578,10 @@ def _filter_unsupported_mm_module(self, module_name: str) -> bool: be filtered out. """ if self.supports_mm: - prefix = module_name.split(".")[0] module_mapping: MultiModelKeys = self.model.get_mm_mapping() - return (prefix in module_mapping.connector - or prefix in module_mapping.tower_model) + prefix_lst = module_mapping.connector + module_mapping.tower_model + return any( + [module_name.startswith(prefix) for prefix in prefix_lst]) return False def _register_packed_modules(self, module_full_name: str) -> None: diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index cd3f7c1b6c4db..0a1b40927e9f9 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -20,7 +20,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig, MultiModalConfig +from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, token_inputs) @@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig @@ -39,6 +40,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs @@ -46,7 +48,7 @@ from vllm.sequence import IntermediateTensors, SequenceData from vllm.utils import is_list_of -from .interfaces import SupportsMultiModal, SupportsPP +from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .utils import (flatten_bn, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -122,8 +124,8 @@ def __init__( # Strided linear layer. assert self._qkv_same_embed_dim, \ 'Visual Attention implementation only supports self-attention' - self.in_proj = nn.Linear(embed_dim, 3 * embed_dim) - self.out_proj = nn.Linear(embed_dim, embed_dim) + self.in_proj = ReplicatedLinear(embed_dim, 3 * embed_dim) + self.out_proj = ReplicatedLinear(embed_dim, embed_dim) self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) def forward( @@ -133,7 +135,7 @@ def forward( ) -> torch.Tensor: # query/key/value: [sq, b, h] sq, b, _ = x.size() - mixed_x_layer = self.in_proj(x) + mixed_x_layer, _ = self.in_proj(x) # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] new_tensor_shape = mixed_x_layer.size()[:-1] + \ @@ -182,7 +184,7 @@ def forward( (self.hidden_size_per_partition,) context_layer = context_layer.view(*new_context_layer_shape) - output = self.out_proj(context_layer) + output, _ = self.out_proj(context_layer) return output @@ -860,11 +862,7 @@ def dummy_data_for_qwen( return seq_data, mm_data -@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen) -@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen) -@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen) -class QWenLMHeadModel(nn.Module, SupportsMultiModal, SupportsPP): +class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): def __init__( self, @@ -872,6 +870,7 @@ def __init__( multimodal_config: MultiModalConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, ): super().__init__() self.config = config @@ -990,3 +989,91 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + +class QWenLLM(QWenBaseModel): + packed_modules_mapping = { + "c_attn": ["c_attn"], + "gate_up_proj": [ + "w2", + "w1", + ], + } + # LoRA specific attributes + supported_lora_modules = [ + "c_attn", + "gate_up_proj", + "c_proj", + ] + + embedding_modules = {} + embedding_padding_modules = [] + + +class QWenVL(QWenBaseModel): + packed_modules_mapping = { + "c_attn": ["c_attn"], + "gate_up_proj": [ + "w2", + "w1", + ], + } + # LoRA specific attributes + supported_lora_modules = [ + "c_attn", + "gate_up_proj", + "c_proj", + # visual module + "out_proj", + "in_proj", + "c_fc", + # resampler + "kv_proj", + ] + + embedding_modules = {} + embedding_padding_modules = [] + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="transformer.h", + connector="transformer.visual.attn_pool", + tower_model="transformer.visual.transformer") + + +@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen) +@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen) +@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen) +class QWenLMHeadModel(QWenBaseModel): + """ + QWenLMHeadModel is not only applicable to LLM but also to VL, which is not + conducive to the current integration logic of LoRA in vLLM. Therefore, it + is necessary to separate them. + """ + # Ensure that the LoRA support check passes when the class is not + # initialized, but set all these attributes to empty. + packed_modules_mapping = {} + supported_lora_modules = [] + embedding_modules = {} + embedding_padding_modules = [] + + def __new__( + cls, + config: PretrainedConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ): + # Initialize VL + if hasattr(config, "visual"): + return QWenVL(config, multimodal_config, cache_config, + quant_config, lora_config) + # Initialize LLM + else: + return QWenLLM(config, multimodal_config, cache_config, + quant_config, lora_config)