diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py
index 35be1cec3d434..8173b780d49ab 100644
--- a/vllm/model_executor/models/intern_vit.py
+++ b/vllm/model_executor/models/intern_vit.py
@@ -97,6 +97,37 @@ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         return embeddings
 
 
+class InternVisionPatchModel(nn.Module):
+
+    def __init__(self, config: PretrainedConfig):
+        super().__init__()
+        self.config = config
+        self.embeddings = InternVisionEmbeddings(config)
+
+    def get_input_embeddings(self):
+        return self.embeddings
+
+    def forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        pixel_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        if pixel_values is None and pixel_embeds is None:
+            raise ValueError(
+                'You have to specify pixel_values or pixel_embeds')
+
+        if pixel_embeds is not None:
+            hidden_states = pixel_embeds
+        elif pixel_values is not None:
+            if pixel_values.ndim == 4:
+                hidden_states = self.embeddings(pixel_values)
+            else:
+                raise ValueError(
+                    f'wrong pixel_values size: {pixel_values.shape}')
+
+        return hidden_states
+
+
 class InternParallelAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
diff --git a/vllm/model_executor/models/internlm2_ve.py b/vllm/model_executor/models/internlm2_ve.py
new file mode 100644
index 0000000000000..dd85b5dfe168b
--- /dev/null
+++ b/vllm/model_executor/models/internlm2_ve.py
@@ -0,0 +1,170 @@
+# -*- coding: utf-8 -*-
+from functools import partial
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from vllm.attention import Attention, AttentionMetadata
+from vllm.config import CacheConfig
+from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              split_tensor_along_last_dim,
+                              tensor_model_parallel_all_gather)
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               RowParallelLinear)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
+
+from .interfaces import SupportsPP
+from .utils import (is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers)
+from vllm.model_executor.models.internlm2 import InternLM2Attention, InternLM2MLP, InternLM2Model, InternLM2ForCausalLM
+
+class InternLM2VEDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
+        self.attention = InternLM2Attention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            cache_config=cache_config,
+            quant_config=quant_config,
+        )
+        self.feed_forward = InternLM2MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+        )
+        self.feed_forward_ve = InternLM2MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+        )
+        self.attention_norm = RMSNorm(config.hidden_size,
+                                      eps=config.rms_norm_eps)
+        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        residual: Optional[torch.Tensor],
+        visual_token_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.attention_norm(hidden_states)
+        else:
+            hidden_states, residual = self.attention_norm(
+                hidden_states, residual)
+        hidden_states = self.attention(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.ffn_norm(hidden_states, residual)
+        if visual_token_mask is not None and visual_token_mask.any():
+            visual_token_mask = visual_token_mask.repeat(1, self.hidden_size).bool()
+            hidden_states[visual_token_mask] = self.feed_forward_ve(hidden_states[visual_token_mask].reshape(-1, self.hidden_size)).flatten()
+            hidden_states[~visual_token_mask] = self.feed_forward(hidden_states[~visual_token_mask].reshape(-1, self.hidden_size)).flatten()
+        else:
+            hidden_states = self.feed_forward(hidden_states)
+        return hidden_states, residual
+
+
+class InternLM2VEModel(InternLM2Model):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(config, cache_config, quant_config)
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: InternLM2VEDecoderLayer(config, cache_config,
+                                                   quant_config),
+            prefix=f"{prefix}.layers")
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        visual_token_mask: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.tok_embeddings(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+        for i in range(self.start_layer, self.end_layer):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                kv_caches[i - self.start_layer],
+                attn_metadata,
+                residual,
+                visual_token_mask=visual_token_mask,
+            )
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class InternLM2VEForCausalLM(InternLM2ForCausalLM):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__(config, cache_config, quant_config)
+        self.model = InternLM2VEModel(config, cache_config, quant_config)
\ No newline at end of file
diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index aada92cdf2456..52b64b87a7eac 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -21,7 +21,7 @@
                          token_inputs)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.models.intern_vit import InternVisionModel
+from vllm.model_executor.models.intern_vit import InternVisionModel, InternVisionPatchModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
@@ -427,13 +427,9 @@ def __init__(self,
         self.downsample_ratio = config.downsample_ratio
         self.ps_version = config.ps_version
 
-        vision_feature_layer = self.select_layer
-        if vision_feature_layer < 0:
-            num_hidden_layers = config.vision_config.num_hidden_layers \
-                + vision_feature_layer + 1
-        else:
-            num_hidden_layers = vision_feature_layer + 1
-        self.vision_model = self._init_vision_model(config, num_hidden_layers)
+        self.llm_arch_name = config.text_config.architectures[0]
+        self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'
+        self.vision_model = self._init_vision_model(config, self.is_mono)
 
         self.language_model = init_vllm_registered_model(
             config.text_config, cache_config, quant_config)
@@ -451,10 +447,18 @@ def sampler(self):
 
         return Sampler()
 
-    def _init_vision_model(self, config: PretrainedConfig,
-                           num_hidden_layers: int):
-        return InternVisionModel(config.vision_config,
-                                 num_hidden_layers_override=num_hidden_layers)
+    def _init_vision_model(self, config: PretrainedConfig, is_mono: bool):
+        if not is_mono:
+            vision_feature_layer = self.select_layer
+            if vision_feature_layer < 0:
+                num_hidden_layers = config.vision_config.num_hidden_layers \
+                    + vision_feature_layer + 1
+            else:
+                num_hidden_layers = vision_feature_layer + 1
+            return InternVisionModel(config.vision_config,
+                                     num_hidden_layers_override=num_hidden_layers)
+        else:
+            return InternVisionPatchModel(config.vision_config)
 
     def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
         vit_hidden_size = config.vision_config.hidden_size
@@ -574,6 +578,7 @@ def forward(
         if intermediate_tensors is not None:
             input_ids = None
             inputs_embeds = None
+            visual_token_mask = None
         else:
             image_input = self._parse_and_validate_image_input(**kwargs)
             if image_input is not None:
@@ -583,16 +588,27 @@
                 inputs_embeds = merge_multimodal_embeddings(
                     input_ids, inputs_embeds, vision_embeddings,
                     self.img_context_token_id)
+                visual_token_mask = (input_ids == self.img_context_token_id).reshape(-1, 1)
                 input_ids = None
             else:
                 inputs_embeds = None
+                visual_token_mask = None
 
-        hidden_states = self.language_model.model(input_ids,
+        if self.is_mono:
+            hidden_states = self.language_model.model(input_ids,
                                                   positions,
                                                   kv_caches,
                                                   attn_metadata,
                                                   intermediate_tensors,
-                                                  inputs_embeds=inputs_embeds)
+                                                  inputs_embeds=inputs_embeds,
+                                                  visual_token_mask=visual_token_mask)
+        else:
+            hidden_states = self.language_model.model(input_ids,
+                                                      positions,
+                                                      kv_caches,
+                                                      attn_metadata,
+                                                      intermediate_tensors,
+                                                      inputs_embeds=inputs_embeds)
         return hidden_states
 
     def compute_logits(
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index f442ce0f63e3e..c0d0cabb6c1df 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -47,6 +47,7 @@
     "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
     "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
+    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
     "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
     "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
     "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
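
Note (illustration, not part of the patch): the heart of InternLM2VEDecoderLayer.forward is routing each token to one of two MLP experts, feed_forward_ve for visual tokens and feed_forward for text tokens, using a boolean mask derived from the image-context token id in internvl.py. The standalone sketch below reproduces that masked-scatter pattern so it can be run and inspected in isolation; TinyMLP, route_tokens, and the toy sizes are hypothetical stand-ins for vLLM's InternLM2MLP and the model's hidden size, not code from this diff.

# Minimal sketch of the visual-expert routing used in InternLM2VEDecoderLayer.
# Assumes a toy MLP; the real layer uses vLLM's InternLM2MLP.
import torch
from torch import nn


class TinyMLP(nn.Module):

    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(torch.relu(x))


def route_tokens(hidden_states: torch.Tensor,
                 visual_token_mask: torch.Tensor,
                 feed_forward: nn.Module,
                 feed_forward_ve: nn.Module) -> torch.Tensor:
    """Send visual tokens through the visual-expert MLP and all other tokens
    through the regular MLP, mirroring InternLM2VEDecoderLayer.forward."""
    hidden_size = hidden_states.shape[-1]
    # Expand the (num_tokens, 1) bool mask across the hidden dimension,
    # matching the .repeat(1, hidden_size) call in the decoder layer.
    expanded_mask = visual_token_mask.repeat(1, hidden_size).bool()
    out = hidden_states.clone()
    if expanded_mask.any():
        out[expanded_mask] = feed_forward_ve(
            hidden_states[expanded_mask].reshape(-1, hidden_size)).flatten()
    if (~expanded_mask).any():
        out[~expanded_mask] = feed_forward(
            hidden_states[~expanded_mask].reshape(-1, hidden_size)).flatten()
    return out


if __name__ == "__main__":
    with torch.no_grad():
        hidden_size, num_tokens = 8, 5
        tokens = torch.randn(num_tokens, hidden_size)
        # Tokens 1 and 3 are "visual", analogous to
        # (input_ids == img_context_token_id).reshape(-1, 1) in internvl.py.
        mask = torch.tensor([[0], [1], [0], [1], [0]], dtype=torch.bool)
        routed = route_tokens(tokens, mask, TinyMLP(hidden_size),
                              TinyMLP(hidden_size))
        print(routed.shape)  # torch.Size([5, 8])

In the patch itself this routing only activates on the mono path: InternVL's forward passes visual_token_mask to the language model only when the text backbone is InternLM2VEForCausalLM (self.is_mono), so the existing InternVL flow is unchanged.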