From e2552622389da8b6af3fb82b9eb66a3d6ba2fd3d Mon Sep 17 00:00:00 2001 From: shunxing12345 <168084185+shunxing12345@users.noreply.github.com> Date: Wed, 27 Nov 2024 19:32:35 +0800 Subject: [PATCH] [Model] Support telechat2 (#10311) Signed-off-by: Isotr0py <2037008807@qq.com> Co-authored-by: xiangw2 Co-authored-by: Isotr0py <2037008807@qq.com> Signed-off-by: Andrew Feldman --- docs/source/models/supported_models.rst | 5 + tests/models/registry.py | 2 + vllm/model_executor/models/llama.py | 6 +- vllm/model_executor/models/registry.py | 2 + vllm/model_executor/models/telechat2.py | 131 +++++++++++++++++++ vllm/transformers_utils/config.py | 4 +- vllm/transformers_utils/configs/__init__.py | 2 + vllm/transformers_utils/configs/telechat2.py | 61 +++++++++ 8 files changed, 210 insertions(+), 3 deletions(-) create mode 100644 vllm/model_executor/models/telechat2.py create mode 100644 vllm/transformers_utils/configs/telechat2.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index b5cbe6915d581..c5fbb30b24e28 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -309,6 +309,11 @@ Text Generation - :code:`upstage/solar-pro-preview-instruct`, etc. - ✅︎ - ✅︎ + * - :code:`TeleChat2ForCausalLM` + - TeleChat2 + - :code:`TeleAI/TeleChat2-3B`, :code:`TeleAI/TeleChat2-7B`, :code:`TeleAI/TeleChat2-35B`, etc. + - ✅︎ + - ✅︎ * - :code:`XverseForCausalLM` - XVERSE - :code:`xverse/XVERSE-7B-Chat`, :code:`xverse/XVERSE-13B-Chat`, :code:`xverse/XVERSE-65B-Chat`, etc. diff --git a/tests/models/registry.py b/tests/models/registry.py index 865e90b3f8b0e..a93bfe907e0d7 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -115,6 +115,8 @@ class _HfExamplesInfo: "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"), "Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"), "SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"), + "TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B", + trust_remote_code=True), "XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat", is_available_online=False, trust_remote_code=True), diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 7cc5547b4a4d5..fffb3fe53b94c 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -501,8 +501,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.lora_config = lora_config - self.model = LlamaModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = self._init_model(vllm_config=vllm_config, prefix=prefix) if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size if lora_config: @@ -539,6 +538,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): normalize=False, softmax=False) + def _init_model(self, vllm_config: VllmConfig, prefix: str = ""): + return LlamaModel(vllm_config=vllm_config, prefix=prefix) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index f5a02a5b25ca2..4462f6ed55a9c 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -91,6 +91,7 @@ "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), "SolarForCausalLM": ("solar", "SolarForCausalLM"), + "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"), "XverseForCausalLM": ("xverse", "XverseForCausalLM"), # [Encoder-decoder] "BartModel": ("bart", "BartForConditionalGeneration"), @@ -118,6 +119,7 @@ "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), "Qwen2ForSequenceClassification": ("qwen2_cls", "Qwen2ForSequenceClassification"), # noqa: E501 + "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"), # [Multimodal] "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), diff --git a/vllm/model_executor/models/telechat2.py b/vllm/model_executor/models/telechat2.py new file mode 100644 index 0000000000000..39c9103527f01 --- /dev/null +++ b/vllm/model_executor/models/telechat2.py @@ -0,0 +1,131 @@ +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Iterable, Set, Tuple + +import torch + +from vllm.config import VllmConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaModel + +from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, + is_pp_missing_parameter) + + +class TeleChat2Model(LlamaModel): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + # 1. Initialize the LlamaModel with bias + vllm_config.model_config.hf_config.bias = True + vllm_config.model_config.hf_config.mlp_bias = True + super().__init__(vllm_config=vllm_config, prefix=prefix) + # 2. Remove the bias from the qkv_proj and gate_up_proj based on config + # Telechat2's gate_up_proj and qkv_proj don't have bias + # see: https://github.com/vllm-project/vllm/pull/10311#issuecomment-2490297566 + for layer in self.layers: + if not isinstance(layer, PPMissingLayer): + layer.self_attn.qkv_proj.bias = None + layer.self_attn.qkv_proj.skip_bias_add = True + layer.mlp.gate_up_proj.bias = None + layer.mlp.gate_up_proj.skip_bias_add = True + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + ('gate_up_proj', 'gate_proj', 0), + ('gate_up_proj', 'up_proj', 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + total_num_heads = self.config.n_head + head_dim = self.config.hidden_size // total_num_heads + for name, loaded_weight in weights: + if "self_attn.key_value" in name: + k_weight = [] + v_weight = [] + for i in range(total_num_heads): + start = i * head_dim * 2 + k_weight.append(loaded_weight[start:start + head_dim, :]) + v_weight.append(loaded_weight[start + head_dim:start + + 2 * head_dim:]) + k_weight = torch.cat(k_weight, dim=0) + v_weight = torch.cat(v_weight, dim=0) + name = name.replace("key_value", "qkv_proj") + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, k_weight, "k") + weight_loader(param, v_weight, "v") + elif "query" in name: + name = name.replace("query", "qkv_proj") + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, "q") + else: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class TeleChat2ForCausalLM(LlamaForCausalLM): + + def _init_model(self, vllm_config: VllmConfig, prefix: str = ""): + return TeleChat2Model(vllm_config=vllm_config, prefix=prefix) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "transformer.": "model.", + }, + orig_to_new_substr={ + ".h.": ".layers.", + ".self_attention.": ".self_attn.", + ".word_embeddings.": ".embed_tokens.", + ".dense.": ".o_proj.", + ".ln_f.": ".norm.", + }, + ) + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights, mapper=hf_to_vllm_mapper) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 4c096acdf2035..3da99bcbee9ae 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -29,7 +29,8 @@ MLPSpeculatorConfig, MPTConfig, NemotronConfig, NVLM_D_Config, Olmo2Config, RWConfig, - SolarConfig, UltravoxConfig) + SolarConfig, Telechat2Config, + UltravoxConfig) # yapf: enable from vllm.transformers_utils.utils import check_gguf_file from vllm.utils import resolve_obj_by_qualname @@ -64,6 +65,7 @@ "NVLM_D": NVLM_D_Config, "olmo2": Olmo2Config, "solar": SolarConfig, + "telechat": Telechat2Config, "ultravox": UltravoxConfig, **_CONFIG_REGISTRY_OVERRIDE_HF } diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 4c721001d8434..c24433cd436b4 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -17,6 +17,7 @@ from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config from vllm.transformers_utils.configs.olmo2 import Olmo2Config from vllm.transformers_utils.configs.solar import SolarConfig +from vllm.transformers_utils.configs.telechat2 import Telechat2Config from vllm.transformers_utils.configs.ultravox import UltravoxConfig __all__ = [ @@ -36,5 +37,6 @@ "NVLM_D_Config", "Olmo2Config", "SolarConfig", + "Telechat2Config", "UltravoxConfig", ] \ No newline at end of file diff --git a/vllm/transformers_utils/configs/telechat2.py b/vllm/transformers_utils/configs/telechat2.py new file mode 100644 index 0000000000000..eb6f5a059169f --- /dev/null +++ b/vllm/transformers_utils/configs/telechat2.py @@ -0,0 +1,61 @@ +# adapted from https://www.modelscope.cn/models/TeleAI/TeleChat2-3B/resolve/master/configuration_telechat2.py +""" Telechat configuration compatible with LlamaConfig. """ + +from transformers.configuration_utils import PretrainedConfig + + +class Telechat2Config(PretrainedConfig): + + model_type = "telechat" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "num_hidden_layers": "n_layer", + "num_attention_heads": "n_head", + "intermediate_size": "ffn_hidden_size", + "rms_norm_eps": "layer_norm_epsilon" + } + + def __init__( + self, + vocab_size=160256, + hidden_size=4096, + n_layer=30, + n_head=32, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + use_cache=True, + bos_token_id=1, + eos_token_id=2, + apply_residual_connection_post_layernorm=False, + hidden_dropout=0.0, + attention_dropout=0.0, + ffn_hidden_size=12288, + training_seqlen=8192, + logn=True, + embed_layernorm=False, + hidden_act="silu", + **kwargs, + ): + self.vocab_size = vocab_size + n_embed = kwargs.pop("n_embed", None) + self.hidden_size = hidden_size if n_embed is None else n_embed + self.n_layer = n_layer + self.n_head = n_head + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.use_cache = use_cache + self.apply_residual_connection_post_layernorm = ( + apply_residual_connection_post_layernorm) + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.logn = logn + self.training_seqlen = training_seqlen + self.embed_layernorm = embed_layernorm + self.num_key_value_heads = kwargs.pop("num_key_value_heads", None) + self.ffn_hidden_size = ffn_hidden_size + self.hidden_act = hidden_act + super().__init__(bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs)