# vllm/model_executor/models/internlm2_ve.py (new file, +170 lines)
# -*- coding: utf-8 -*-
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.internlm2 import (InternLM2Attention,
                                                  InternLM2ForCausalLM,
                                                  InternLM2MLP,
                                                  InternLM2Model)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)


class InternLM2VEDecoderLayer(nn.Module):
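    """InternLM2 decoder layer with a separate visual-expert (VE) MLP:
    tokens flagged by `visual_token_mask` go through `feed_forward_ve`,
    the rest through the regular `feed_forward`."""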

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.attention = InternLM2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.feed_forward = InternLM2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
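        # Second MLP ("visual expert") with the same shape as feed_forward;
        # applied only to visual tokens in forward().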
        self.feed_forward_ve = InternLM2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.attention_norm = RMSNorm(config.hidden_size,
                                      eps=config.rms_norm_eps)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
        visual_token_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
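        # `visual_token_mask` flags visual-token rows (one entry per token;
        # it is repeated across the hidden dim below); None means text-only.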
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.attention_norm(hidden_states)
        else:
            hidden_states, residual = self.attention_norm(
                hidden_states, residual)
        hidden_states = self.attention(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.ffn_norm(hidden_states, residual)
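        # Route each token through exactly one MLP: visual tokens through the
        # visual expert, text tokens through the regular feed-forward.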
        if visual_token_mask is not None and visual_token_mask.any():
            visual_token_mask = visual_token_mask.repeat(
                1, self.hidden_size).bool()
            hidden_states[visual_token_mask] = self.feed_forward_ve(
                hidden_states[visual_token_mask].reshape(
                    -1, self.hidden_size)).flatten()
            hidden_states[~visual_token_mask] = self.feed_forward(
                hidden_states[~visual_token_mask].reshape(
                    -1, self.hidden_size)).flatten()
        else:
            hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual


class InternLM2VEModel(InternLM2Model):
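    """InternLM2Model whose decoder stack uses InternLM2VEDecoderLayer and
    threads a `visual_token_mask` through every layer."""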

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config, cache_config, quant_config)
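        # Replace the decoder stack built by the parent with VE layers;
        # make_layers also handles the pipeline-parallel partitioning
        # (start_layer..end_layer is this rank's slice).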
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: InternLM2VEDecoderLayer(config, cache_config,
                                                   quant_config),
            prefix=f"{prefix}.layers")

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        visual_token_mask: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
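        # The first pipeline rank starts from embeddings (or the provided
        # inputs_embeds); later ranks resume from the previous rank's tensors.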
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.tok_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
                visual_token_mask=visual_token_mask,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class InternLM2VEForCausalLM(InternLM2ForCausalLM):
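    """InternLM2ForCausalLM with the text-only model replaced by
    InternLM2VEModel."""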

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__(config, cache_config, quant_config)
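        # Swap the base text-only model for the visual-expert variant.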
        self.model = InternLM2VEModel(config, cache_config, quant_config)
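

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this commit): how the boolean-mask routing
# in InternLM2VEDecoderLayer.forward selects whole token rows. All names
# below are local to this example.
#
#     import torch
#
#     hidden_size = 4
#     hidden_states = torch.arange(12.).reshape(3, hidden_size)
#     # One entry per token, shape (num_tokens, 1); token 1 is "visual".
#     visual_token_mask = torch.tensor([[0], [1], [0]])
#     mask = visual_token_mask.repeat(1, hidden_size).bool()  # (3, 4)
#     # Boolean indexing flattens the selected rows; reshape recovers them.
#     visual_rows = hidden_states[mask].reshape(-1, hidden_size)   # (1, 4)
#     text_rows = hidden_states[~mask].reshape(-1, hidden_size)    # (2, 4)
#     # After each expert MLP, .flatten() writes the rows back in place.
# ---------------------------------------------------------------------------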