Skip to content

Commit

Permalink
[Model] Support telechat2 (vllm-project#10311)
Browse files Browse the repository at this point in the history
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: xiangw2 <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
  • Loading branch information
3 people authored and weilong.yu committed Dec 13, 2024
1 parent 8341a04 commit ac4fd13
Show file tree
Hide file tree
Showing 8 changed files with 210 additions and 3 deletions.
5 changes: 5 additions & 0 deletions docs/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,11 @@ Text Generation
- :code:`upstage/solar-pro-preview-instruct`, etc.
- ✅︎
- ✅︎
* - :code:`TeleChat2ForCausalLM`
- TeleChat2
- :code:`TeleAI/TeleChat2-3B`, :code:`TeleAI/TeleChat2-7B`, :code:`TeleAI/TeleChat2-35B`, etc.
- ✅︎
- ✅︎
* - :code:`XverseForCausalLM`
- XVERSE
- :code:`xverse/XVERSE-7B-Chat`, :code:`xverse/XVERSE-13B-Chat`, :code:`xverse/XVERSE-65B-Chat`, etc.
Expand Down
2 changes: 2 additions & 0 deletions tests/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ class _HfExamplesInfo:
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
"Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
"SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"),
"TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B",
trust_remote_code=True),
"XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat",
is_available_online=False,
trust_remote_code=True),
Expand Down
6 changes: 4 additions & 2 deletions vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,8 +501,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.config = config
self.lora_config = lora_config

self.model = LlamaModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.model = self._init_model(vllm_config=vllm_config, prefix=prefix)
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
Expand Down Expand Up @@ -539,6 +538,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
normalize=False,
softmax=False)

def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
return LlamaModel(vllm_config=vllm_config, prefix=prefix)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)

Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
"SolarForCausalLM": ("solar", "SolarForCausalLM"),
"TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
# [Encoder-decoder]
"BartModel": ("bart", "BartForConditionalGeneration"),
Expand Down Expand Up @@ -118,6 +119,7 @@
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
"Qwen2ForSequenceClassification": ("qwen2_cls", "Qwen2ForSequenceClassification"), # noqa: E501
"TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
# [Multimodal]
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
Expand Down
131 changes: 131 additions & 0 deletions vllm/model_executor/models/telechat2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable, Set, Tuple

import torch

from vllm.config import VllmConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaModel

from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
is_pp_missing_parameter)


class TeleChat2Model(LlamaModel):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# 1. Initialize the LlamaModel with bias
vllm_config.model_config.hf_config.bias = True
vllm_config.model_config.hf_config.mlp_bias = True
super().__init__(vllm_config=vllm_config, prefix=prefix)
# 2. Remove the bias from the qkv_proj and gate_up_proj based on config
# Telechat2's gate_up_proj and qkv_proj don't have bias
# see: https://github.com/vllm-project/vllm/pull/10311#issuecomment-2490297566
for layer in self.layers:
if not isinstance(layer, PPMissingLayer):
layer.self_attn.qkv_proj.bias = None
layer.self_attn.qkv_proj.skip_bias_add = True
layer.mlp.gate_up_proj.bias = None
layer.mlp.gate_up_proj.skip_bias_add = True

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
('gate_up_proj', 'gate_proj', 0),
('gate_up_proj', 'up_proj', 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
total_num_heads = self.config.n_head
head_dim = self.config.hidden_size // total_num_heads
for name, loaded_weight in weights:
if "self_attn.key_value" in name:
k_weight = []
v_weight = []
for i in range(total_num_heads):
start = i * head_dim * 2
k_weight.append(loaded_weight[start:start + head_dim, :])
v_weight.append(loaded_weight[start + head_dim:start +
2 * head_dim:])
k_weight = torch.cat(k_weight, dim=0)
v_weight = torch.cat(v_weight, dim=0)
name = name.replace("key_value", "qkv_proj")
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, k_weight, "k")
weight_loader(param, v_weight, "v")
elif "query" in name:
name = name.replace("query", "qkv_proj")
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, "q")
else:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params


class TeleChat2ForCausalLM(LlamaForCausalLM):

def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
return TeleChat2Model(vllm_config=vllm_config, prefix=prefix)

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:

hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"transformer.": "model.",
},
orig_to_new_substr={
".h.": ".layers.",
".self_attention.": ".self_attn.",
".word_embeddings.": ".embed_tokens.",
".dense.": ".o_proj.",
".ln_f.": ".norm.",
},
)
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
4 changes: 3 additions & 1 deletion vllm/transformers_utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
MLPSpeculatorConfig, MPTConfig,
NemotronConfig, NVLM_D_Config,
Olmo2Config, RWConfig,
SolarConfig, UltravoxConfig)
SolarConfig, Telechat2Config,
UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import resolve_obj_by_qualname
Expand Down Expand Up @@ -64,6 +65,7 @@
"NVLM_D": NVLM_D_Config,
"olmo2": Olmo2Config,
"solar": SolarConfig,
"telechat": Telechat2Config,
"ultravox": UltravoxConfig,
**_CONFIG_REGISTRY_OVERRIDE_HF
}
Expand Down
2 changes: 2 additions & 0 deletions vllm/transformers_utils/configs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
from vllm.transformers_utils.configs.olmo2 import Olmo2Config
from vllm.transformers_utils.configs.solar import SolarConfig
from vllm.transformers_utils.configs.telechat2 import Telechat2Config
from vllm.transformers_utils.configs.ultravox import UltravoxConfig

__all__ = [
Expand All @@ -36,5 +37,6 @@
"NVLM_D_Config",
"Olmo2Config",
"SolarConfig",
"Telechat2Config",
"UltravoxConfig",
]
61 changes: 61 additions & 0 deletions vllm/transformers_utils/configs/telechat2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# adapted from https://www.modelscope.cn/models/TeleAI/TeleChat2-3B/resolve/master/configuration_telechat2.py
""" Telechat configuration compatible with LlamaConfig. """

from transformers.configuration_utils import PretrainedConfig


class Telechat2Config(PretrainedConfig):

model_type = "telechat"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {
"num_hidden_layers": "n_layer",
"num_attention_heads": "n_head",
"intermediate_size": "ffn_hidden_size",
"rms_norm_eps": "layer_norm_epsilon"
}

def __init__(
self,
vocab_size=160256,
hidden_size=4096,
n_layer=30,
n_head=32,
layer_norm_epsilon=1e-5,
initializer_range=0.02,
use_cache=True,
bos_token_id=1,
eos_token_id=2,
apply_residual_connection_post_layernorm=False,
hidden_dropout=0.0,
attention_dropout=0.0,
ffn_hidden_size=12288,
training_seqlen=8192,
logn=True,
embed_layernorm=False,
hidden_act="silu",
**kwargs,
):
self.vocab_size = vocab_size
n_embed = kwargs.pop("n_embed", None)
self.hidden_size = hidden_size if n_embed is None else n_embed
self.n_layer = n_layer
self.n_head = n_head
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.use_cache = use_cache
self.apply_residual_connection_post_layernorm = (
apply_residual_connection_post_layernorm)
self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.logn = logn
self.training_seqlen = training_seqlen
self.embed_layernorm = embed_layernorm
self.num_key_value_heads = kwargs.pop("num_key_value_heads", None)
self.ffn_hidden_size = ffn_hidden_size
self.hidden_act = hidden_act
super().__init__(bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
**kwargs)

0 comments on commit ac4fd13

Please sign in to comment.