From 06aea5da8c122e9a0c5d29ea2898b44338cdf31d Mon Sep 17 00:00:00 2001 From: Zhaokai Wang <53330871+wzk1015@users.noreply.github.com> Date: Mon, 11 Nov 2024 11:09:44 +0800 Subject: [PATCH] Support Mono-InternVL with PyTorch backend (#2727) * support Mono-InternVL; fix typos * update readme * add assertion for FP16 * add assertion for FP16 * update _SUPPORTED_ARCHS --- .github/CONTRIBUTING.md | 14 +- README.md | 2 + README_zh-CN.md | 2 + docs/en/multi_modal/internvl.md | 13 +- docs/en/multi_modal/vl_pipeline.md | 1 + docs/en/supported_models/supported_models.md | 5 + docs/zh_cn/multi_modal/internvl.md | 17 +- docs/zh_cn/multi_modal/vl_pipeline.md | 1 + .../supported_models/supported_models.md | 5 + lmdeploy/model.py | 3 +- lmdeploy/pytorch/models/baichuan.py | 2 +- lmdeploy/pytorch/models/chatglm2.py | 2 +- lmdeploy/pytorch/models/cogvlm.py | 2 +- lmdeploy/pytorch/models/dbrx.py | 2 +- lmdeploy/pytorch/models/deepseek.py | 2 +- lmdeploy/pytorch/models/falcon.py | 2 +- lmdeploy/pytorch/models/gemma.py | 2 +- lmdeploy/pytorch/models/internlm.py | 2 +- lmdeploy/pytorch/models/internlm2.py | 2 +- lmdeploy/pytorch/models/internlm2_ve.py | 338 ++++++++++++++++++ lmdeploy/pytorch/models/internvl.py | 63 +++- lmdeploy/pytorch/models/llama.py | 2 +- lmdeploy/pytorch/models/minicpm3.py | 2 +- lmdeploy/pytorch/models/mistral.py | 2 +- lmdeploy/pytorch/models/mllama.py | 4 +- lmdeploy/pytorch/models/module_map.py | 6 + lmdeploy/pytorch/models/phi3.py | 2 +- lmdeploy/pytorch/models/qwen.py | 2 +- lmdeploy/pytorch/models/qwen2.py | 2 +- lmdeploy/pytorch/models/qwen2_moe.py | 2 +- lmdeploy/pytorch/models/qwen2_vl.py | 2 +- lmdeploy/pytorch/models/starcoder2.py | 2 +- lmdeploy/pytorch/supported_models.py | 2 + 33 files changed, 458 insertions(+), 54 deletions(-) create mode 100644 lmdeploy/pytorch/models/internlm2_ve.py diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 19668fe9e4..20bd3a5f48 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -1,6 +1,6 @@ -## Contributing to InternLM +## Contributing to LMDeploy -Welcome to the InternLM community, all kinds of contributions are welcomed, including but not limited to +Welcome to the LMDeploy community, all kinds of contributions are welcomed, including but not limited to **Fix bug** @@ -56,7 +56,7 @@ upstream git@github.com:InternLM/lmdeploy.git (push) #### 2. Configure pre-commit -You should configure [pre-commit](https://pre-commit.com/#intro) in the local development environment to make sure the code style matches that of InternLM. **Note**: The following code should be executed under the lmdeploy directory. +You should configure [pre-commit](https://pre-commit.com/#intro) in the local development environment to make sure the code style matches that of LMDeploy. **Note**: The following code should be executed under the lmdeploy directory. ```shell pip install -U pre-commit @@ -96,7 +96,7 @@ git checkout -b yhc/refactor_contributing_doc In subsequent development, if the master branch of the local repository is behind the master branch of "upstream", we need to pull the upstream for synchronization, and then execute the above command: ```shell -git pull upstream master +git pull upstream main ``` #### 4. Commit the code and pass the unit test @@ -151,7 +151,7 @@ Find more details about Pull Request description in [pull request guidelines](#p -IternLM will run unit test for the posted Pull Request on different platforms (Linux, Window, Mac), based on different versions of Python, PyTorch, CUDA to make sure the code is correct. We can see the specific test information by clicking `Details` in the above image so that we can modify the code. +LMDeploy will run unit test for the posted Pull Request on different platforms (Linux, Window, Mac), based on different versions of Python, PyTorch, CUDA to make sure the code is correct. We can see the specific test information by clicking `Details` in the above image so that we can modify the code. (3) If the Pull Request passes the CI, then you can wait for the review from other developers. You'll modify the code based on the reviewer's comments, and repeat the steps [4](#4-commit-the-code-and-pass-the-unit-test)-[5](#5-push-the-code-to-remote) until all reviewers approve it. Then, we will merge it ASAP. @@ -163,14 +163,14 @@ If your local branch conflicts with the latest master branch of "upstream", you' ```shell git fetch --all --prune -git rebase upstream/master +git rebase upstream/main ``` or ```shell git fetch --all --prune -git merge upstream/master +git merge upstream/main ``` If you are very good at handling conflicts, then you can use rebase to resolve conflicts, as this will keep your commit logs tidy. If you are not familiar with `rebase`, then you can use `merge` to resolve conflicts. diff --git a/README.md b/README.md index 6ca5fadedd..efbb87a22e 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,7 @@ ______________________________________________________________________
2024 +- \[2024/11\] Support Mono-InternVL with PyTorch engine - \[2024/10\] PyTorchEngine supports graph mode on ascend platform, doubling the inference speed - \[2024/09\] LMDeploy PyTorchEngine adds support for [Huawei Ascend](./docs/en/get_started/ascend/get_started.md). See supported models [here](docs/en/supported_models/supported_models.md) - \[2024/09\] LMDeploy PyTorchEngine achieves 1.3x faster on Llama3-8B inference by introducing CUDA graph @@ -155,6 +156,7 @@ For detailed inference benchmarks in more devices and more settings, please refe
  • DeepSeek-VL (7B)
  • InternVL-Chat (v1.1-v1.5)
  • InternVL2 (1B-76B)
  • +
  • Mono-InternVL (2B)
  • MiniGeminiLlama (7B)
  • CogVLM-Chat (17B)
  • CogVLM2-Chat (19B)
  • diff --git a/README_zh-CN.md b/README_zh-CN.md index 663b7b24ab..477fed6f79 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -26,6 +26,7 @@ ______________________________________________________________________
    2024 +- \[2024/11\] PyTorch engine 支持 Mono-InternVL 模型 - \[2024/10\] PyTorchEngine 在 ascend 平台上支持了图模式,推理性能提高了 1 倍 - \[2024/09\] LMDeploy PyTorchEngine 增加了对 [华为 Ascend](docs/zh_cn/get_started/ascend/get_started.md) 的支持。支持的模型请见[这里](docs/zh_cn/supported_models/supported_models.md) - \[2024/09\] 通过引入 CUDA Graph,LMDeploy PyTorchEngine 在 Llama3-8B 推理上实现了 1.3 倍的加速 @@ -156,6 +157,7 @@ LMDeploy TurboMind 引擎拥有卓越的推理能力,在各种规模的模型
  • DeepSeek-VL (7B)
  • InternVL-Chat (v1.1-v1.5)
  • InternVL2 (1B-76B)
  • +
  • Mono-InternVL (2B)
  • MiniGeminiLlama (7B)
  • CogVLM-Chat (17B)
  • CogVLM2-Chat (19B)
  • diff --git a/docs/en/multi_modal/internvl.md b/docs/en/multi_modal/internvl.md index 24c79357c0..bd33649139 100644 --- a/docs/en/multi_modal/internvl.md +++ b/docs/en/multi_modal/internvl.md @@ -2,12 +2,13 @@ LMDeploy supports the following InternVL series of models, which are detailed in the table below: -| Model | Size | Supported Inference Engine | -| :---------: | :--------: | :------------------------: | -| InternVL | 13B-19B | TurboMind | -| InternVL1.5 | 2B-26B | TurboMind, PyTorch | -| InternVL2 | 1B, 4B | PyTorch | -| InternVL2 | 2B, 8B-76B | TurboMind, PyTorch | +| Model | Size | Supported Inference Engine | +| :-----------: | :--------: | :------------------------: | +| InternVL | 13B-19B | TurboMind | +| InternVL1.5 | 2B-26B | TurboMind, PyTorch | +| InternVL2 | 1B, 4B | PyTorch | +| InternVL2 | 2B, 8B-76B | TurboMind, PyTorch | +| Mono-InternVL | 2B | PyTorch | The next chapter demonstrates how to deploy an InternVL model using LMDeploy, with [InternVL2-8B](https://huggingface.co/OpenGVLab/InternVL2-8B) as an example. diff --git a/docs/en/multi_modal/vl_pipeline.md b/docs/en/multi_modal/vl_pipeline.md index 72eb0b4595..4881b99071 100644 --- a/docs/en/multi_modal/vl_pipeline.md +++ b/docs/en/multi_modal/vl_pipeline.md @@ -9,6 +9,7 @@ Currently, it supports the following models. - [Yi-VL](https://huggingface.co/01-ai/Yi-VL-6B) - [DeepSeek-VL](https://huggingface.co/deepseek-ai/deepseek-vl-7b-chat) - [InternVL](https://huggingface.co/OpenGVLab/InternVL-Chat-V1-5) +- [Mono-InternVL](https://huggingface.co/OpenGVLab/Mono-InternVL-2B) - [MGM](https://huggingface.co/YanweiLi/MGM-7B) - [XComposer](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b) - [CogVLM](https://github.com/InternLM/lmdeploy/tree/main/docs/en/multi_modal/cogvlm.md) diff --git a/docs/en/supported_models/supported_models.md b/docs/en/supported_models/supported_models.md index 1f344e78bb..371e4968e0 100644 --- a/docs/en/supported_models/supported_models.md +++ b/docs/en/supported_models/supported_models.md @@ -80,6 +80,7 @@ The TurboMind engine doesn't support window attention. Therefore, for models tha | LLaVA(1.5,1.6) | 7B-34B | MLLM | Yes | Yes | Yes | No | - | | InternVL(v1.5) | 2B-26B | MLLM | Yes | Yes | Yes | No | Yes | | InternVL2 | 1B-40B | MLLM | Yes | Yes | Yes | No | - | +| Mono-InternVL | 2B | MLLM | Yes\* | Yes | Yes | No | - | | Gemma2 | 9B-27B | LLM | Yes | Yes | Yes | No | - | | GLM4 | 9B | LLM | Yes | Yes | Yes | No | No | | GLM-4V | 9B | MLLM | Yes | Yes | Yes | No | No | @@ -88,6 +89,10 @@ The TurboMind engine doesn't support window attention. Therefore, for models tha | Phi-3.5-MoE | 16x3.8B | LLM | Yes | Yes | No | No | - | | Phi-3.5-vision | 4.2B | MLLM | Yes | Yes | No | No | - | +```{note} +* Currently Mono-InternVL does not support FP16 due to numerical instability. Please use BF16 instead. +``` + ## PyTorchEngine on Huawei Ascend Platform | Model | Size | Type | FP16/BF16 | W4A16 | diff --git a/docs/zh_cn/multi_modal/internvl.md b/docs/zh_cn/multi_modal/internvl.md index 3d948353a5..e5dae1a89c 100644 --- a/docs/zh_cn/multi_modal/internvl.md +++ b/docs/zh_cn/multi_modal/internvl.md @@ -2,14 +2,15 @@ LMDeploy 支持 InternVL 系列模型,具体如下: -| Model | Size | Supported Inference Engine | -| :---------: | :--------: | :------------------------: | -| InternVL | 13B-19B | TurboMind | -| InternVL1.5 | 2B-26B | TurboMind, PyTorch | -| InternVL2 | 1B, 4B | PyTorch | -| InternVL2 | 2B, 8B-76B | TurboMind, PyTorch | - -本文将以[InternVL2-8B](https://huggingface.co/OpenGVLab/InternVL2-8B)为例,演示使用 LMDeploy 部署 InternVL 系列模型的方法 +| Model | Size | Supported Inference Engine | +| :-----------: | :--------: | :------------------------: | +| InternVL | 13B-19B | TurboMind | +| InternVL1.5 | 2B-26B | TurboMind, PyTorch | +| InternVL2 | 1B, 4B | PyTorch | +| InternVL2 | 2B, 8B-76B | TurboMind, PyTorch | +| Mono-InternVL | 2B | PyTorch | + +本文将以[InternVL2-8B](https://huggingface.co/OpenGVLab/InternVL2-8B)为例,演示使用 LMDeploy 部署 InternVL 系列模型的方法。 ## 安装 diff --git a/docs/zh_cn/multi_modal/vl_pipeline.md b/docs/zh_cn/multi_modal/vl_pipeline.md index 31533b38f7..570598311a 100644 --- a/docs/zh_cn/multi_modal/vl_pipeline.md +++ b/docs/zh_cn/multi_modal/vl_pipeline.md @@ -9,6 +9,7 @@ LMDeploy 把视觉-语言模型(VLM)复杂的推理过程,抽象为简单 - [Yi-VL](https://huggingface.co/01-ai/Yi-VL-6B) - [DeepSeek-VL](https://huggingface.co/deepseek-ai/deepseek-vl-7b-chat) - [InternVL](https://huggingface.co/OpenGVLab/InternVL-Chat-V1-5) +- [Mono-InternVL](https://huggingface.co/OpenGVLab/Mono-InternVL-2B) - [MGM](https://huggingface.co/YanweiLi/MGM-7B) - [XComposer](https://huggingface.co/internlm/internlm-xcomposer2-vl-7b) - [CogVLM](https://github.com/InternLM/lmdeploy/tree/main/docs/zh_cn/multi_modal/cogvlm.md) diff --git a/docs/zh_cn/supported_models/supported_models.md b/docs/zh_cn/supported_models/supported_models.md index ac061cf1ae..7d59a59899 100644 --- a/docs/zh_cn/supported_models/supported_models.md +++ b/docs/zh_cn/supported_models/supported_models.md @@ -80,6 +80,7 @@ turbomind 引擎不支持 window attention。所以,对于应用了 window att | LLaVA(1.5,1.6) | 7B-34B | MLLM | Yes | Yes | Yes | No | - | | InternVL(v1.5) | 2B-26B | MLLM | Yes | Yes | Yes | No | Yes | | InternVL2 | 1B-40B | MLLM | Yes | Yes | Yes | No | - | +| Mono-InternVL | 2B | MLLM | Yes\* | Yes | Yes | No | - | | Gemma2 | 9B-27B | LLM | Yes | Yes | Yes | No | - | | GLM4 | 9B | LLM | Yes | Yes | Yes | No | No | | GLM-4V | 9B | MLLM | Yes | Yes | Yes | No | No | @@ -88,6 +89,10 @@ turbomind 引擎不支持 window attention。所以,对于应用了 window att | Phi-3.5-MoE | 16x3.8B | LLM | Yes | Yes | No | No | - | | Phi-3.5-vision | 4.2B | MLLM | Yes | Yes | No | No | - | +```{note} +* Currently Mono-InternVL does not support FP16 due to numerical instability. Please use BF16 instead. +``` + ## PyTorchEngine 华为昇腾平台 | Model | Size | Type | FP16/BF16 | W4A16 | diff --git a/lmdeploy/model.py b/lmdeploy/model.py index 98f8e373ba..2b3a0a4e1d 100644 --- a/lmdeploy/model.py +++ b/lmdeploy/model.py @@ -578,7 +578,8 @@ def match(cls, model_path: str) -> Optional[str]: model_path (str): the model path used for matching. """ path = model_path.lower() - if 'internvl2' in path and 'internvl2-4b' not in path: + if ('internvl2' in path + and 'internvl2-4b' not in path) or 'mono-internvl' in path: return 'internvl2-internlm2' diff --git a/lmdeploy/pytorch/models/baichuan.py b/lmdeploy/pytorch/models/baichuan.py index 6bd18d9e58..583cd19fe9 100644 --- a/lmdeploy/pytorch/models/baichuan.py +++ b/lmdeploy/pytorch/models/baichuan.py @@ -167,7 +167,7 @@ def __init__(self, # build attention layer self.self_attn = BaichuanAttention(config, dtype=dtype, device=device) - # builf MLP + # build MLP self.mlp = MLP(config, dtype=dtype, device=device) # build input layer norm diff --git a/lmdeploy/pytorch/models/chatglm2.py b/lmdeploy/pytorch/models/chatglm2.py index efb44b2431..8d7a21a0a6 100644 --- a/lmdeploy/pytorch/models/chatglm2.py +++ b/lmdeploy/pytorch/models/chatglm2.py @@ -279,7 +279,7 @@ def __init__(self, # build attention layer self.self_attention = SelfAttention(config, dtype=dtype, device=device) - # builf MLP + # build MLP self.mlp = MLP(config, dtype=dtype, device=device) # build input layer norm diff --git a/lmdeploy/pytorch/models/cogvlm.py b/lmdeploy/pytorch/models/cogvlm.py index b53c74d95a..6caf10df00 100644 --- a/lmdeploy/pytorch/models/cogvlm.py +++ b/lmdeploy/pytorch/models/cogvlm.py @@ -263,7 +263,7 @@ def __init__(self, dtype=dtype, device=device) - # builf MLP + # build MLP self.mlp = VisionExpertMLP(config, dtype=dtype, device=device) # build input layer norm diff --git a/lmdeploy/pytorch/models/dbrx.py b/lmdeploy/pytorch/models/dbrx.py index dd1191625b..e71ff17fe9 100644 --- a/lmdeploy/pytorch/models/dbrx.py +++ b/lmdeploy/pytorch/models/dbrx.py @@ -301,7 +301,7 @@ def __init__(self, dtype=dtype, device=device) - # builf MLP + # build MLP self.ffn = DbrxFFN(config, dtype=dtype, device=device) def forward( diff --git a/lmdeploy/pytorch/models/deepseek.py b/lmdeploy/pytorch/models/deepseek.py index f4e80fb048..5742baeee5 100644 --- a/lmdeploy/pytorch/models/deepseek.py +++ b/lmdeploy/pytorch/models/deepseek.py @@ -250,7 +250,7 @@ def __init__(self, # build attention layer self.self_attn = DeepseekAttention(config, dtype=dtype, device=device) - # builf MLP + # build MLP self.mlp = (DeepseekMoE(config, dtype=dtype, device=device) if (config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace diff --git a/lmdeploy/pytorch/models/falcon.py b/lmdeploy/pytorch/models/falcon.py index e767d29849..8f8659dc5e 100644 --- a/lmdeploy/pytorch/models/falcon.py +++ b/lmdeploy/pytorch/models/falcon.py @@ -179,7 +179,7 @@ def __init__(self, dtype=dtype, device=device) - # builf MLP + # build MLP self.mlp = FalconMLP(config, dtype=dtype, device=device) if not hasattr(config, 'num_ln_in_parallel_attn'): diff --git a/lmdeploy/pytorch/models/gemma.py b/lmdeploy/pytorch/models/gemma.py index 2d9f85f2ca..450767bda3 100644 --- a/lmdeploy/pytorch/models/gemma.py +++ b/lmdeploy/pytorch/models/gemma.py @@ -177,7 +177,7 @@ def __init__(self, dtype=dtype, device=device) - # builf MLP + # build MLP self.mlp = GemmaMLP(config, dtype=dtype, device=device) # build input layer norm diff --git a/lmdeploy/pytorch/models/internlm.py b/lmdeploy/pytorch/models/internlm.py index f8869543be..99c622e4ac 100644 --- a/lmdeploy/pytorch/models/internlm.py +++ b/lmdeploy/pytorch/models/internlm.py @@ -161,7 +161,7 @@ def __init__(self, # build attention layer self.self_attn = InternLMAttention(config, dtype=dtype, device=device) - # builf MLP + # build MLP self.mlp = InternLMMLP(config, dtype=dtype, device=device) # build input layer norm diff --git a/lmdeploy/pytorch/models/internlm2.py b/lmdeploy/pytorch/models/internlm2.py index a87c848e65..6cbc2ccff3 100644 --- a/lmdeploy/pytorch/models/internlm2.py +++ b/lmdeploy/pytorch/models/internlm2.py @@ -160,7 +160,7 @@ def __init__(self, # build attention layer self.attention = InternLM2Attention(config, dtype=dtype, device=device) - # builf MLP + # build MLP self.feed_forward = InternLM2MLP(config, dtype=dtype, device=device) # build input layer norm diff --git a/lmdeploy/pytorch/models/internlm2_ve.py b/lmdeploy/pytorch/models/internlm2_ve.py new file mode 100644 index 0000000000..b1a2329597 --- /dev/null +++ b/lmdeploy/pytorch/models/internlm2_ve.py @@ -0,0 +1,338 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers.configuration_utils import PretrainedConfig + +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.models.internlm2 import InternLM2Attention, InternLM2MLP +from lmdeploy.pytorch.nn import RMSNorm, RopeType, build_rotary_embedding +from lmdeploy.pytorch.nn.linear import build_rowwise_linear +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight + +from .utils.cudagraph import CudaGraphMixin + + +class InternLM2VEDecoderLayer(nn.Module): + """decoder layer with visual expert.""" + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + quantization_config = getattr(config, 'quantization_config', None) + + # build attention layer + self.attention = InternLM2Attention(config, dtype=dtype, device=device) + + # build MLP + self.feed_forward = InternLM2MLP(config, dtype=dtype, device=device) + + # build visual expert + self.feed_forward_ve = InternLM2MLP(config, dtype=dtype, device=device) + + # build input layer norm + self.attention_norm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build attention layer norm + self.ffn_norm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + def forward( + self, + hidden_states: torch.Tensor, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[List[torch.FloatTensor]], + residual: Optional[torch.Tensor] = None, + attn_metadata: Any = None, + vision_embedding_indexing: Optional[torch.Tensor] = None, + text_embedding_indexing: Optional[torch.Tensor] = None, + ): + + if residual is None: + residual = hidden_states + hidden_states = self.attention_norm(hidden_states) + else: + hidden_states, residual = self.attention_norm( + hidden_states, residual) + + # Self Attention + hidden_states = self.attention( + hidden_states=hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.ffn_norm(hidden_states, residual) + if vision_embedding_indexing is not None: + hidden_states[:, + vision_embedding_indexing, :] = self.feed_forward_ve( + hidden_states[:, vision_embedding_indexing, :]. + reshape(-1, self.hidden_size)).unsqueeze(0) + if text_embedding_indexing is not None: + hidden_states[:, + text_embedding_indexing, :] = self.feed_forward( + hidden_states[:, text_embedding_indexing, :]. + reshape(-1, self.hidden_size)).unsqueeze(0) + else: + hidden_states = self.feed_forward(hidden_states) + + outputs = (hidden_states, residual) + return outputs + + +class InternLM2VEModel(nn.Module): + """internlm2 model with visual expert.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + quantization_config = getattr(config, 'quantization_config', None) + + self.tok_embeddings = nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=dtype, + device=device) + + # build all decode layers + self.layers = nn.ModuleList([ + InternLM2VEDecoderLayer(config, + layer_idx, + dtype=dtype, + device=device) + for layer_idx in range(config.num_hidden_layers) + ]) + + # build norm + self.norm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build rotary embedding in Model + rope_scaling = config.rope_scaling + scaling_factor = 1.0 + emb_type = RopeType.LinearScaling + if rope_scaling is not None: + scaling_factor = rope_scaling.get('factor', scaling_factor) + rope_type = rope_scaling['type'] + if rope_type == 'linear': + emb_type = RopeType.LinearScaling + if rope_type == 'dynamic': + emb_type = RopeType.DynamicNTKScaling + else: + raise RuntimeError(f'Unsupported rope type: {rope_type}') + rope_dim = config.hidden_size // config.num_attention_heads + rope_max_pos_emb = config.max_position_embeddings + rope_base = config.rope_theta + self.rotary_emb = build_rotary_embedding( + rope_dim, + rope_max_pos_emb, + rope_base, + scaling_factor, + emb_type=emb_type, + ) + + def forward( + self, + input_ids: torch.LongTensor = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + attn_metadata: Any = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_embedding_indexing: Optional[torch.Tensor] = None, + text_embedding_indexing: Optional[torch.Tensor] = None, + ): + """Rewrite of forward.""" + + # token embedding + if inputs_embeds is None: + inputs_embeds = self.tok_embeddings(input_ids) + + hidden_states = inputs_embeds + + # rotary embedding + cos, sin = self.rotary_emb(hidden_states, position_ids) + cos, sin = cos[0], sin[0] + rotary_pos_emb = (cos, sin) + + # decoding + residual = None + for idx, decoder_layer in enumerate(self.layers): + past_key_value = past_key_values[idx] + hidden_states, residual = decoder_layer( + hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + residual=residual, + attn_metadata=attn_metadata, + vision_embedding_indexing=vision_embedding_indexing, + text_embedding_indexing=text_embedding_indexing, + ) + + # norm + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + def get_input_embeddings(self): + """get input embeddings.""" + return self.tok_embeddings + + +class InternLM2VEForCausalLM(nn.Module, CudaGraphMixin): + """rewrote model of InternLM2ForCausalLM with visual expert.""" + + packed_modules_mapping = { + 'gate_up_proj': [ + 'w1', + 'w3', + ], + } + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + # build Model + self.model = InternLM2VEModel(config, dtype=dtype, device=device) + # build lm_head + self.output = build_rowwise_linear(config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + device=device) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, + vision_embedding_indexing: Optional[torch.Tensor] = None, + text_embedding_indexing: Optional[torch.Tensor] = None, + **kwargs, + ): + """model forward, return logits.""" + hidden_states = self.model( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + vision_embedding_indexing=vision_embedding_indexing, + text_embedding_indexing=text_embedding_indexing, + ) + return hidden_states + + def get_logits(self, hidden_states: torch.Tensor): + """compute logits of the model output.""" + return self.output(hidden_states) + + def support_cuda_graph( + self, + input_ids: torch.Tensor, + attn_metadata: Any = None, + **kwargs, + ): + """support cudagraph.""" + if not attn_metadata.is_decoding: + return False + seq_lens = input_ids.size(1) + if seq_lens <= 512: + return True + return False + + def get_input_embeddings(self): + """get input embeddings.""" + return self.model.get_input_embeddings() + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + """prepare input.""" + # get input_ids, position_ids and attention metadatas + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + + # process vision embeddings + vision_embeddings = context.input_embeddings + vision_embedding_indexing = context.input_embedding_indexing + if vision_embeddings is not None and len(vision_embeddings) > 0: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds[:, + vision_embedding_indexing, :] = vision_embeddings.to( + inputs_embeds) + + # inputs of forward + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + # modify from vllm + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('.gate_up_proj', '.w1', 0), + ('.gate_up_proj', '.w3', 1), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if 'rotary_emb.inv_freq' in name: + continue + if ('rotary_emb.cos_cached' in name + or 'rotary_emb.sin_cached' in name): + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + if '.wqkv' in name: + param = params_dict[name] + q, k, v = param.weight_spliter(loaded_weight, layout='hgd') + load_weight(param, q, shard_id='q') + load_weight(param, k, shard_id='k') + load_weight(param, v, shard_id='v') + else: + param = params_dict[name] + load_weight(param, loaded_weight) diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py index 8981436113..70dd8f2159 100644 --- a/lmdeploy/pytorch/models/internvl.py +++ b/lmdeploy/pytorch/models/internvl.py @@ -26,6 +26,15 @@ def __init__(self, dtype=dtype, device=device) + self.llm_arch_name = llm_config.architectures[0] + + # for Mono-InternVL + self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM' + if self.is_mono: + assert dtype != torch.float16, ( + 'Currently Mono-InternVL does not support FP16 due to' + 'numerical instability. Please use BF16 instead.') + def forward( self, input_ids: torch.Tensor, @@ -33,13 +42,25 @@ def forward( past_key_values: List[List[torch.Tensor]], attn_metadata: Any = None, inputs_embeds: torch.Tensor = None, + vision_embedding_indexing: torch.Tensor = None, + text_embedding_indexing: torch.Tensor = None, **kwargs, ): - return self.language_model.forward(input_ids=input_ids, - inputs_embeds=inputs_embeds, - past_key_values=past_key_values, - position_ids=position_ids, - attn_metadata=attn_metadata) + if self.is_mono: + return self.language_model.forward( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + position_ids=position_ids, + attn_metadata=attn_metadata, + vision_embedding_indexing=vision_embedding_indexing, + text_embedding_indexing=text_embedding_indexing) + else: + return self.language_model.forward(input_ids=input_ids, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + position_ids=position_ids, + attn_metadata=attn_metadata) def get_logits(self, hidden_states: torch.Tensor): """compute logits of the model output.""" @@ -70,13 +91,31 @@ def prepare_inputs_for_generation( vision_embedding_indexing, :] = vision_embeddings.to( inputs_embeds) - return dict( - input_ids=input_ids, - position_ids=position_ids, - past_key_values=past_key_values, - attn_metadata=attn_metadata, - inputs_embeds=inputs_embeds, - ) + if self.is_mono and vision_embedding_indexing is not None: + all_indices = torch.arange(input_ids.shape[1]).to(input_ids) + text_embedding_indexing = all_indices[ + ~torch.isin(all_indices, vision_embedding_indexing)] + if vision_embedding_indexing.numel() == 0: + vision_embedding_indexing = None + if text_embedding_indexing.numel() == 0: + text_embedding_indexing = None + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + vision_embedding_indexing=vision_embedding_indexing, + text_embedding_indexing=text_embedding_indexing, + ) + else: + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): """load weights.""" diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py index 525c8e3d34..f38c5ef02b 100644 --- a/lmdeploy/pytorch/models/llama.py +++ b/lmdeploy/pytorch/models/llama.py @@ -163,7 +163,7 @@ def __init__(self, # build attention layer self.self_attn = LlamaAttention(config, dtype=dtype, device=device) - # builf MLP + # build MLP self.mlp = LlamaMLP(config, dtype=dtype, device=device) # build input layer norm diff --git a/lmdeploy/pytorch/models/minicpm3.py b/lmdeploy/pytorch/models/minicpm3.py index 56a1c4edf1..72a2b8a045 100644 --- a/lmdeploy/pytorch/models/minicpm3.py +++ b/lmdeploy/pytorch/models/minicpm3.py @@ -237,7 +237,7 @@ def __init__(self, # build attention layer self.self_attn = MiniCPMAttention(config, dtype=dtype, device=device) - # builf MLP + # build MLP self.mlp = MiniCPMMLP(config, dtype=dtype, device=device) # build input layer norm diff --git a/lmdeploy/pytorch/models/mistral.py b/lmdeploy/pytorch/models/mistral.py index 4c369b716b..04af4c8526 100644 --- a/lmdeploy/pytorch/models/mistral.py +++ b/lmdeploy/pytorch/models/mistral.py @@ -162,7 +162,7 @@ def __init__(self, # build attention layer self.self_attn = MistralAttention(config, dtype=dtype, device=device) - # builf MLP + # build MLP self.mlp = MistralMLP(config, dtype=dtype, device=device) # build input layer norm diff --git a/lmdeploy/pytorch/models/mllama.py b/lmdeploy/pytorch/models/mllama.py index a16abd8b91..2596fe5299 100644 --- a/lmdeploy/pytorch/models/mllama.py +++ b/lmdeploy/pytorch/models/mllama.py @@ -267,7 +267,7 @@ def __init__(self, # build attention layer self.self_attn = LlamaAttention(config, dtype=dtype, device=device) - # builf MLP + # build MLP self.mlp = LlamaMLP(config, dtype=dtype, device=device) # build input layer norm @@ -336,7 +336,7 @@ def __init__(self, dtype=dtype, device=device) - # builf MLP + # build MLP self.mlp = LlamaMLP(config, dtype=dtype, device=device) # build input layer norm diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index bc6385d8b2..e6b5f6e29e 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -149,6 +149,12 @@ f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internvl.InternVLChatModel' }) +# mono-internvl +MODULE_MAP.update({ + 'InternLM2VEForCausalLM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm2_ve.InternLM2VEForCausalLM', +}) + # phi3 vision MODULE_MAP.update({ 'Phi3VForCausalLM': diff --git a/lmdeploy/pytorch/models/phi3.py b/lmdeploy/pytorch/models/phi3.py index a2859e3e3e..f9477fdab8 100644 --- a/lmdeploy/pytorch/models/phi3.py +++ b/lmdeploy/pytorch/models/phi3.py @@ -165,7 +165,7 @@ def __init__(self, # build attention layer self.self_attn = Phi3Attention(config, dtype=dtype, device=device) - # builf MLP + # build MLP self.mlp = Phi3MLP(config, dtype=dtype, device=device) # build input layer norm diff --git a/lmdeploy/pytorch/models/qwen.py b/lmdeploy/pytorch/models/qwen.py index 50b9fd4ee8..bf856461a3 100644 --- a/lmdeploy/pytorch/models/qwen.py +++ b/lmdeploy/pytorch/models/qwen.py @@ -174,7 +174,7 @@ def __init__(self, # build attention layer self.attn = QWenAttention(config, dtype=dtype, device=device) - # builf MLP + # build MLP self.mlp = QWenMLP(config, dtype=dtype, device=device) # build input layer norm diff --git a/lmdeploy/pytorch/models/qwen2.py b/lmdeploy/pytorch/models/qwen2.py index de6a7a58e1..82be75e167 100644 --- a/lmdeploy/pytorch/models/qwen2.py +++ b/lmdeploy/pytorch/models/qwen2.py @@ -163,7 +163,7 @@ def __init__(self, # build attention layer self.self_attn = Qwen2Attention(config, dtype=dtype, device=device) - # builf MLP + # build MLP self.mlp = Qwen2MLP(config, dtype=dtype, device=device) # build input layer norm diff --git a/lmdeploy/pytorch/models/qwen2_moe.py b/lmdeploy/pytorch/models/qwen2_moe.py index fdaff8e0cc..1aff14483a 100644 --- a/lmdeploy/pytorch/models/qwen2_moe.py +++ b/lmdeploy/pytorch/models/qwen2_moe.py @@ -258,7 +258,7 @@ def __init__(self, # build attention layer self.self_attn = Qwen2MoeAttention(config, dtype=dtype, device=device) - # builf MLP + # build MLP if (layer_idx not in config.mlp_only_layers) and ( config.num_experts > 0) and ((layer_idx + 1) % config.decoder_sparse_step == 0): diff --git a/lmdeploy/pytorch/models/qwen2_vl.py b/lmdeploy/pytorch/models/qwen2_vl.py index 1a1dc1e1da..b10baaa4d5 100644 --- a/lmdeploy/pytorch/models/qwen2_vl.py +++ b/lmdeploy/pytorch/models/qwen2_vl.py @@ -192,7 +192,7 @@ def __init__(self, # build attention layer self.self_attn = Qwen2Attention(config, dtype=dtype, device=device) - # builf MLP + # build MLP self.mlp = Qwen2MLP(config, dtype=dtype, device=device) # build input layer norm diff --git a/lmdeploy/pytorch/models/starcoder2.py b/lmdeploy/pytorch/models/starcoder2.py index 7498df606f..4a6b175ca3 100644 --- a/lmdeploy/pytorch/models/starcoder2.py +++ b/lmdeploy/pytorch/models/starcoder2.py @@ -168,7 +168,7 @@ def __init__(self, dtype=dtype, device=device) - # builf MLP + # build MLP self.mlp = Starcoder2MLP(config, dtype=dtype, device=device) # build input layer norm diff --git a/lmdeploy/pytorch/supported_models.py b/lmdeploy/pytorch/supported_models.py index 3a5baf8fc6..21418188dd 100644 --- a/lmdeploy/pytorch/supported_models.py +++ b/lmdeploy/pytorch/supported_models.py @@ -62,6 +62,8 @@ DeepseekV2ForCausalLM=True, # internvl InternVLChatModel=True, + # mono-internvl + InternLM2VEForCausalLM=True, # gemma2 Gemma2ForCausalLM=True, # phi3.5-moe