diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index 3f6d8ef42cd5f..3457ec6b8e73b 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -295,16 +295,29 @@ ) ], ), - "minicpmv": VLMTestInfo( + "minicpmv_25": VLMTestInfo( models=["openbmb/MiniCPM-Llama3-V-2_5"], - test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + test_type=VLMTestType.IMAGE, prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 img_idx_to_prompt=lambda idx: "(./)\n", max_model_len=4096, max_num_seqs=2, get_stop_token_ids=lambda tok: [tok.eos_id, tok.eot_id], postprocess_inputs=model_utils.wrap_inputs_post_processor, - hf_output_post_proc=model_utils.minicmpv_trunc_hf_output, + hf_output_post_proc=model_utils.minicpmv_trunc_hf_output, + ), + "minicpmv_26": VLMTestInfo( + models=["openbmb/MiniCPM-V-2_6"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "(./)\n", + max_model_len=4096, + max_num_seqs=2, + get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501 + postprocess_inputs=model_utils.ignore_inputs_post_processor( + "image_sizes" + ), + hf_output_post_proc=model_utils.minicpmv_trunc_hf_output, ), # Tests for phi3v currently live in another file because of a bug in # transformers. Once this issue is fixed, we can enable them here instead. diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py index 849857b4232e7..15f15dd7d8030 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py @@ -170,7 +170,7 @@ def paligemma_vllm_to_hf_output(vllm_output: RunnerOutput, ####### Post-processors for HF outputs -def minicmpv_trunc_hf_output(hf_output: RunnerOutput, +def minicpmv_trunc_hf_output(hf_output: RunnerOutput, model: str) -> RunnerOutput: output_ids, output_str, out_logprobs = hf_output if output_str.endswith("<|eot_id|>"): @@ -197,6 +197,17 @@ def process(hf_inputs: BatchEncoding, dtype: str): return process +def ignore_inputs_post_processor( + hf_inp_key: str) -> Callable[[BatchEncoding, str], BatchEncoding]: + """Gets a handle to a post processor which ignores a given key.""" + + def process(hf_inputs: BatchEncoding, dtype: str): + del hf_inputs[hf_inp_key] + return hf_inputs + + return process + + def wrap_inputs_post_processor(hf_inputs: BatchEncoding, dtype: str): return {"model_inputs": hf_inputs} diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 5570771ac917b..8c6f7c6e06515 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -242,7 +242,7 @@ def _load_per_tensor_weight_scale(self, shard_id: str, def _load_model_weight_or_group_weight_scale(self, shard_dim: int, expert_data: torch.Tensor, shard_id: str, - loaded_weight: torch.tensor, + loaded_weight: torch.Tensor, tp_rank: int): # Load grouped weight scales for group quantization # or model weights @@ -261,7 +261,7 @@ def _load_model_weight_or_group_weight_scale(self, shard_dim: int, def _load_per_channel_weight_scale(self, expert_data: torch.Tensor, shard_dim: int, shard_id: str, - loaded_weight: torch.tensor, + loaded_weight: torch.Tensor, tp_rank: int): # for per channel weight quantization if shard_id == "w2": @@ -274,7 +274,7 @@ def _load_per_channel_weight_scale(self, expert_data: torch.Tensor, tp_rank=tp_rank) def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, - shard_id: str, loaded_weight: torch.tensor, tp_rank: int): + shard_id: str, loaded_weight: torch.Tensor, tp_rank: int): # Index the loaded weight for tp sharding. # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim @@ -292,7 +292,7 @@ def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, expert_data.copy_(loaded_weight) def _load_w2(self, expert_data: torch.Tensor, shard_dim: int, - shard_id: str, loaded_weight: torch.tensor, tp_rank: int): + shard_id: str, loaded_weight: torch.Tensor, tp_rank: int): # Index the loaded weight for tp sharding. # down_proj: "RowParallel" so tp sharding on input_dim @@ -311,7 +311,7 @@ def _load_single_value(self, param: torch.nn.Parameter, param_data[expert_id] = loaded_weight def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor, - shard_dim: int, loaded_weight: torch.tensor, tp_rank: int): + shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int): if shard_id == "w2": self._load_w2(shard_id=shard_id, diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index c9a573278a136..6254d26c7060d 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -52,7 +52,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -378,6 +378,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.hidden_size, org_num_embeddings=config.vocab_size, ) + self.num_experts = getattr(self.config, "num_experts", 0) self._init_layers(prefix, config, cache_config, quant_config) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.make_empty_intermediate_tensors = ( @@ -437,6 +438,73 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + expert_params_mapping = [ + # (param_name, weight_name, expert_id) + ("ws" if weight_name in ["w1", "w3"] else "w2s", + f"experts.{expert_id}.{weight_name}.weight", expert_id) + for expert_id in range(self.num_experts) + for weight_name in ["w1", "w2", "w3"] + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, expert_id in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { @@ -480,8 +548,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.cache_config = cache_config self.quant_config = quant_config - self.num_experts = getattr(self.config, "num_experts", 0) - self._init_model(vllm_config=vllm_config, prefix=prefix) + self.model = self._init_model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + unpadded_vocab_size = config.vocab_size if lora_config: unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -506,8 +575,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model.make_empty_intermediate_tensors) def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""): - self.model = MiniCPMModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + return MiniCPMModel(vllm_config=vllm_config, prefix=prefix) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -546,72 +614,9 @@ def sample( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - expert_params_mapping = [ - # (param_name, weight_name, expert_id) - ("ws" if weight_name in ["w1", "w3"] else "w2s", - f"experts.{expert_id}.{weight_name}.weight", expert_id) - for expert_id in range(self.num_experts) - for weight_name in ["w1", "w2", "w3"] - ] - params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - continue - # With tie_word_embeddings, we can skip lm_head.weight - # The weight might appear unnecessarily in the files if the model is - # processed with quantization, LoRA, fine-tuning, etc. - if self.config.tie_word_embeddings and "lm_head.weight" in name: - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - for param_name, weight_name, expert_id in expert_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - weight_name, - expert_id=expert_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/minicpm3.py b/vllm/model_executor/models/minicpm3.py index c66be2d9c2d07..e9d7eada1d16c 100644 --- a/vllm/model_executor/models/minicpm3.py +++ b/vllm/model_executor/models/minicpm3.py @@ -40,7 +40,7 @@ MiniCPMForCausalLM, MiniCPMModel) -from .utils import make_layers, maybe_prefix +from .utils import make_layers class MiniCPM3Attention(nn.Module): @@ -248,5 +248,4 @@ class MiniCPM3ForCausalLM(MiniCPMForCausalLM): } def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""): - self.model = MiniCPM3Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + return MiniCPM3Model(vllm_config=vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index aacce477e0460..1e8f9bd4cf418 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -22,7 +22,7 @@ """Inference-only MiniCPM-V model compatible with HuggingFace weights.""" import math import re -from functools import partial +from functools import cached_property, partial from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) @@ -37,19 +37,15 @@ from vllm.config import VllmConfig from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) -from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2, get_2d_sincos_pos_embed) from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.utils import set_default_torch_dtype -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.llama import LlamaModel -from vllm.model_executor.models.minicpm import MiniCPMModel +from vllm.model_executor.models.llama import LlamaForCausalLM +from vllm.model_executor.models.minicpm import MiniCPMForCausalLM from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.models.qwen2 import Qwen2Model -from vllm.model_executor.models.utils import LLMWrapper +from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.image import cached_get_image_processor @@ -58,11 +54,7 @@ from .idefics2_vision_model import Idefics2VisionTransformer from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP -from .utils import is_pp_missing_parameter, maybe_prefix - -_KEYS_TO_MODIFY_MAPPING = { - "llm.lm_head": "lm_head", -} +from .utils import AutoWeightsLoader, maybe_prefix RawImageType = Union[Image.Image, torch.Tensor] @@ -297,10 +289,9 @@ def input_processor_for_minicpmv(ctx: InputContext, inputs: DecoderOnlyInputs): def get_placeholder(image_size: Tuple[int, int], num_image: int): if version == (2, 0) or version == (2, 5): - return image_processor. \ - get_slice_image_placeholder(image_size) - return image_processor. \ - get_slice_image_placeholder(image_size, num_image) + return image_processor.get_slice_image_placeholder(image_size) + return image_processor.get_slice_image_placeholder( + image_size, num_image) prompt = inputs.get("prompt") token_ids = inputs.get("prompt_token_ids") @@ -400,37 +391,32 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.vpm = self.init_vision_module(config, quant_config, prefix=maybe_prefix(prefix, "vpm")) - param_dtype = torch.get_default_dtype() - self.vpm.to(dtype=param_dtype) self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else self.vpm.embeddings.embed_dim) self.embed_dim = self.config.hidden_size + self.resampler = self.init_resampler(self.embed_dim, self.vision_dim, quant_config=quant_config, prefix=maybe_prefix( prefix, "resampler")) - self.resampler.to(device="cuda", dtype=param_dtype) - # TODO: why is there _KEYS_TO_MODIFY_MAPPING? lm_head should be in llm - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "llm.lm_head")) - self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.llm.make_empty_intermediate_tensors) + @cached_property + def sampler(self): + if hasattr(self.llm, "sampler"): + return self.llm.sampler + + return get_sampler() + def get_embedding( self, input_ids: torch.Tensor, image_inputs: Optional[MiniCPMVImageInputs], ) -> Tuple[torch.Tensor, torch.Tensor]: - vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids) - if hasattr(self.config, "scale_emb"): - vlm_embedding *= self.config.scale_emb + vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids) if image_inputs is None: # No image vision_hidden_states = torch.tensor([], device=input_ids.device) @@ -575,7 +561,7 @@ def forward( # for `torch.compile` integration input_ids = None - output = self.llm( + output = self.llm.model( input_ids=input_ids, positions=positions, kv_caches=kv_caches, @@ -590,9 +576,7 @@ def compute_logits( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits + return self.llm.compute_logits(hidden_states, sampling_metadata) def sample( self, @@ -604,52 +588,8 @@ def sample( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): - if key_to_modify in name: - name = name.replace(key_to_modify, new_key) - if "rotary_emb.inv_freq" in name: - continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - continue - use_default_weight_loading = False - if self.is_default_weight_loading(name): - use_default_weight_loading = True - else: - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - use_default_weight_loading = True - if use_default_weight_loading: - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) def get_mm_mapping(self) -> MultiModelKeys: """ @@ -693,9 +633,6 @@ def get_vision_hidden_states(self, data: MiniCPMVImageInputs) -> torch.Tensor: raise NotImplementedError - def is_default_weight_loading(self, name: str) -> bool: - raise NotImplementedError - class MiniCPMV2_0(MiniCPMVBaseModel): @@ -708,8 +645,7 @@ def init_llm( vllm_config: VllmConfig, prefix: str = "", ) -> nn.Module: - return LLMWrapper(MiniCPMModel(vllm_config=vllm_config, prefix=prefix), - name="model") + return MiniCPMForCausalLM(vllm_config=vllm_config, prefix=prefix) def init_vision_module( self, @@ -717,11 +653,12 @@ def init_vision_module( quant_config: Optional[QuantizationConfig], prefix: str = "", ) -> nn.Module: - # TODO :refactor this vision model + # TODO: refactor this vision model try: import timm except ImportError: raise ImportError("Please install timm==0.9.10") from ImportError + with set_default_torch_dtype(torch.float16): model = timm.create_model( "vit_so400m_patch14_siglip_384.webli", @@ -731,6 +668,8 @@ def init_vision_module( dynamic_img_pad=True, ) + model = model.to(dtype=torch.get_default_dtype()) + if (isinstance(model, timm.models.VisionTransformer) and model.attn_pool is not None): model.attn_pool = torch.nn.Identity() @@ -759,7 +698,7 @@ def init_resampler(self, quant_config=quant_config, prefix=prefix) - return resampler + return resampler.to(device="cuda", dtype=torch.get_default_dtype()) def get_vision_embedding( self, @@ -790,9 +729,6 @@ def get_vision_hidden_states(self, return self.get_vision_embedding(pixel_values) - def is_default_weight_loading(self, name: str) -> bool: - return "resampler" in name or "vpm" in name - class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): packed_modules_mapping = { @@ -843,8 +779,7 @@ def init_llm( vllm_config: VllmConfig, prefix: str = "", ) -> nn.Module: - return LLMWrapper(LlamaModel(vllm_config=vllm_config, prefix=prefix), - name="model") + return LlamaForCausalLM(vllm_config=vllm_config, prefix=prefix) def init_vision_module( self, @@ -871,7 +806,8 @@ def init_resampler(self, kv_dim=vision_dim, quant_config=quant_config, prefix=prefix) - return resampler + + return resampler.to(device="cuda", dtype=torch.get_default_dtype()) def get_vision_embedding( self, @@ -913,9 +849,6 @@ def get_vision_hidden_states(self, return self.get_vision_embedding(all_pixel_values.type(dtype), patch_attn_mask, tgt_sizes) - def is_default_weight_loading(self, name: str) -> bool: - return "resampler" in name - class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): packed_modules_mapping = { @@ -966,8 +899,7 @@ def init_llm( vllm_config: VllmConfig, prefix: str = "", ) -> nn.Module: - return LLMWrapper(Qwen2Model(vllm_config=vllm_config, prefix=prefix), - name="model") + return Qwen2ForCausalLM(vllm_config=vllm_config, prefix=prefix) def init_vision_module( self, @@ -995,7 +927,8 @@ def init_resampler(self, kv_dim=vision_dim, quant_config=quant_config, prefix=prefix) - return resampler + + return resampler.to(device="cuda", dtype=torch.get_default_dtype()) def get_vision_embedding( self, @@ -1043,9 +976,6 @@ def get_vision_hidden_states(self, return self.resampler(vision_embedding, tgt_sizes) - def is_default_weight_loading(self, name: str) -> bool: - return "resampler" in name - _SUPPORT_VERSION = { (2, 0): MiniCPMV2_0, diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 4c13cbc953273..a6b40a233439b 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,7 +1,7 @@ import itertools from dataclasses import dataclass, field -from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, - Optional, Protocol, Set, Tuple, Union, overload) +from typing import (Callable, Dict, Iterable, List, Literal, Mapping, Optional, + Protocol, Set, Tuple, Union, overload) import torch import torch.nn as nn @@ -560,30 +560,6 @@ def make_empty_intermediate_tensors( return make_empty_intermediate_tensors -class LLMWrapper(nn.Module): - """ - To align with the key names of LoRA trained with PEFT, we need to add an - additional layer to the llm's implementation. - """ - - def __init__(self, llm: nn.Module, name: str) -> None: - super().__init__() - self.model_name = name - setattr(self, name, llm) - - def __getattr__(self, key: str): - llm = super().__getattr__(self.model_name) - if key == self.model_name: - return llm - - return getattr(llm, key) - - # We need to explicitly override this - def __call__(self, *args: Any, **kwargs: Any) -> Any: - llm = super().__getattr__(self.model_name) - return llm(*args, **kwargs) - - def get_vit_attn_backend(support_fa: bool = False) -> _Backend: """ Get the available attention backend for Vision Transformer.