From 6493ee42f25275aac135f09be03b5572addc684d Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Wed, 23 Oct 2024 16:19:05 +0000
Subject: [PATCH] Init

Signed-off-by: Jee Jee Li
---
 vllm/model_executor/models/qwen.py | 115 +++++++++++++++++++++++++++--
 1 file changed, 108 insertions(+), 7 deletions(-)

diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py
index cd3f7c1b6c4db..319c37937d2df 100644
--- a/vllm/model_executor/models/qwen.py
+++ b/vllm/model_executor/models/qwen.py
@@ -20,7 +20,7 @@
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
-from vllm.config import CacheConfig, MultiModalConfig
+from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
                          token_inputs)
@@ -39,6 +39,7 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
@@ -46,7 +47,7 @@
 from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.utils import is_list_of
 
-from .interfaces import SupportsMultiModal, SupportsPP
+from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
 from .utils import (flatten_bn, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers)
 
@@ -860,11 +861,7 @@ def dummy_data_for_qwen(
     return seq_data, mm_data
 
 
-@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
-@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
-class QWenLMHeadModel(nn.Module, SupportsMultiModal, SupportsPP):
+class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
 
     def __init__(
         self,
@@ -872,6 +869,7 @@ def __init__(
         multimodal_config: MultiModalConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -990,3 +988,106 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+
+
+class QWenLLM(QWenBaseModel):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        # vision encoder
+        "fc1",
+        "fc2",
+        "out_proj",
+        # language model
+        "qkv_proj",  # same name as the vision encoder
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+        # resampler
+        "kv_proj",
+    ]
+
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+
+class QWenVL(QWenBaseModel):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        # vision encoder
+        "fc1",
+        "fc2",
+        "out_proj",
+        # language model
+        "qkv_proj",  # same name as the vision encoder
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+        # resampler
+        "kv_proj",
+    ]
+
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefixes in multimodal models.
+        """
+        return MultiModelKeys.from_string_field(language_model="llm",
+                                                connector="resampler",
+                                                tower_model="vpm")
+
+
+@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
+@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
+@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
+class QWenLMHeadModel(QWenBaseModel):
+    """
+    QWenLMHeadModel serves both the text-only (LLM) and vision-language (VL)
+    model variants, which does not fit the current LoRA integration logic in
+    vLLM. Therefore, the two variants are separated into dedicated classes.
+    """
+    # Ensure that the LoRA support check passes when the class is not
+    # initialized, but set all these attributes to empty.
+    packed_modules_mapping = {}
+    supported_lora_modules = []
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def __new__(
+        cls,
+        config: PretrainedConfig,
+        multimodal_config: MultiModalConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
+    ):
+        # Dispatch to the text-only or vision-language variant; pass
+        # lora_config through so it is not silently dropped.
+        if multimodal_config is None:
+            return QWenLLM(config, multimodal_config, cache_config,
+                           quant_config, lora_config)
+        else:
+            return QWenVL(config, multimodal_config, cache_config,
+                          quant_config, lora_config)
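
Usage sketch (not part of the diff): a minimal example of what this split is
intended to enable, namely serving a Qwen checkpoint with a LoRA adapter
through vLLM's public API. LLM, SamplingParams, LoRARequest, and enable_lora
are existing vLLM entrypoints; the adapter directory "./qwen-vl-lora" and the
image file are hypothetical placeholders.

    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    # QWenLMHeadModel.__new__ dispatches to QWenVL for this multimodal
    # checkpoint; a text-only Qwen checkpoint would resolve to QWenLLM.
    llm = LLM(model="Qwen/Qwen-VL-Chat",
              trust_remote_code=True,
              enable_lora=True)

    sampling_params = SamplingParams(temperature=0.0, max_tokens=64)

    # "./qwen-vl-lora" is a hypothetical local LoRA adapter directory.
    outputs = llm.generate(
        "Picture 1: <img>demo.jpeg</img>\nDescribe the image.",
        sampling_params,
        lora_request=LoRARequest("qwen-vl-lora", 1, "./qwen-vl-lora"))
    print(outputs[0].outputs[0].text)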