Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Support telechat2 #10311

Merged
merged 27 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,11 @@ Text Generation
- :code:`upstage/solar-pro-preview-instruct`, etc.
- ✅︎
- ✅︎
* - :code:`TeleChat2ForCausalLM`
- TeleChat2
- :code:`TeleAI/TeleChat2-3B`, :code:`TeleAI/TeleChat2-7B`, :code:`TeleAI/TeleChat2-35B`, etc.
- ✅︎
- ✅︎
* - :code:`XverseForCausalLM`
- XVERSE
- :code:`xverse/XVERSE-7B-Chat`, :code:`xverse/XVERSE-13B-Chat`, :code:`xverse/XVERSE-65B-Chat`, etc.
Expand Down
2 changes: 2 additions & 0 deletions tests/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ class _HfExamplesInfo:
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
"Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
"SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"),
"TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B",
trust_remote_code=True),
"XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat",
is_available_online=False,
trust_remote_code=True),
Expand Down
6 changes: 4 additions & 2 deletions vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,8 +510,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.config = config
self.lora_config = lora_config

self.model = LlamaModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.model = self._init_model(vllm_config=vllm_config, prefix=prefix)
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
Expand Down Expand Up @@ -548,6 +547,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
normalize=False,
softmax=False)

def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
return LlamaModel(vllm_config=vllm_config, prefix=prefix)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)

Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
"SolarForCausalLM": ("solar", "SolarForCausalLM"),
"TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
# [Encoder-decoder]
"BartModel": ("bart", "BartForConditionalGeneration"),
Expand Down Expand Up @@ -118,6 +119,7 @@
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
"Qwen2ForSequenceClassification": ("qwen2_cls", "Qwen2ForSequenceClassification"), # noqa: E501
"TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
# [Multimodal]
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
Expand Down
131 changes: 131 additions & 0 deletions vllm/model_executor/models/telechat2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable, Set, Tuple

import torch

from vllm.config import VllmConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaModel

from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
is_pp_missing_parameter)


class TeleChat2Model(LlamaModel):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# 1. Initialize the LlamaModel with bias
vllm_config.model_config.hf_config.bias = True
vllm_config.model_config.hf_config.mlp_bias = True
super().__init__(vllm_config=vllm_config, prefix=prefix)
# 2. Remove the bias from the qkv_proj and gate_up_proj based on config
# Telechat2's gate_up_proj and qkv_proj don't have bias
# see: https://github.com/vllm-project/vllm/pull/10311#issuecomment-2490297566
for layer in self.layers:
if not isinstance(layer, PPMissingLayer):
layer.self_attn.qkv_proj.bias = None
layer.self_attn.qkv_proj.skip_bias_add = True
layer.mlp.gate_up_proj.bias = None
layer.mlp.gate_up_proj.skip_bias_add = True

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
('gate_up_proj', 'gate_proj', 0),
('gate_up_proj', 'up_proj', 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
total_num_heads = self.config.n_head
head_dim = self.config.hidden_size // total_num_heads
for name, loaded_weight in weights:
if "self_attn.key_value" in name:
k_weight = []
v_weight = []
for i in range(total_num_heads):
start = i * head_dim * 2
k_weight.append(loaded_weight[start:start + head_dim, :])
v_weight.append(loaded_weight[start + head_dim:start +
2 * head_dim:])
k_weight = torch.cat(k_weight, dim=0)
v_weight = torch.cat(v_weight, dim=0)
name = name.replace("key_value", "qkv_proj")
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, k_weight, "k")
weight_loader(param, v_weight, "v")
elif "query" in name:
name = name.replace("query", "qkv_proj")
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, "q")
else:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params


class TeleChat2ForCausalLM(LlamaForCausalLM):

def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
return TeleChat2Model(vllm_config=vllm_config, prefix=prefix)

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:

hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"transformer.": "model.",
},
orig_to_new_substr={
".h.": ".layers.",
".self_attention.": ".self_attn.",
".word_embeddings.": ".embed_tokens.",
".dense.": ".o_proj.",
".ln_f.": ".norm.",
},
)
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
4 changes: 3 additions & 1 deletion vllm/transformers_utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
MLPSpeculatorConfig, MPTConfig,
NemotronConfig, NVLM_D_Config,
Olmo2Config, RWConfig,
SolarConfig, UltravoxConfig)
SolarConfig, Telechat2Config,
UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import resolve_obj_by_qualname
Expand Down Expand Up @@ -64,6 +65,7 @@
"NVLM_D": NVLM_D_Config,
"olmo2": Olmo2Config,
"solar": SolarConfig,
"telechat": Telechat2Config,
"ultravox": UltravoxConfig,
**_CONFIG_REGISTRY_OVERRIDE_HF
}
Expand Down
2 changes: 2 additions & 0 deletions vllm/transformers_utils/configs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
from vllm.transformers_utils.configs.olmo2 import Olmo2Config
from vllm.transformers_utils.configs.solar import SolarConfig
from vllm.transformers_utils.configs.telechat2 import Telechat2Config
from vllm.transformers_utils.configs.ultravox import UltravoxConfig

__all__ = [
Expand All @@ -36,5 +37,6 @@
"NVLM_D_Config",
"Olmo2Config",
"SolarConfig",
"Telechat2Config",
"UltravoxConfig",
]
61 changes: 61 additions & 0 deletions vllm/transformers_utils/configs/telechat2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# adapted from https://www.modelscope.cn/models/TeleAI/TeleChat2-3B/resolve/master/configuration_telechat2.py
""" Telechat configuration compatible with LlamaConfig. """

from transformers.configuration_utils import PretrainedConfig


class Telechat2Config(PretrainedConfig):

model_type = "telechat"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {
"num_hidden_layers": "n_layer",
"num_attention_heads": "n_head",
"intermediate_size": "ffn_hidden_size",
"rms_norm_eps": "layer_norm_epsilon"
}

def __init__(
self,
vocab_size=160256,
hidden_size=4096,
n_layer=30,
n_head=32,
layer_norm_epsilon=1e-5,
initializer_range=0.02,
use_cache=True,
bos_token_id=1,
eos_token_id=2,
apply_residual_connection_post_layernorm=False,
hidden_dropout=0.0,
attention_dropout=0.0,
ffn_hidden_size=12288,
training_seqlen=8192,
logn=True,
embed_layernorm=False,
hidden_act="silu",
**kwargs,
):
self.vocab_size = vocab_size
n_embed = kwargs.pop("n_embed", None)
self.hidden_size = hidden_size if n_embed is None else n_embed
self.n_layer = n_layer
self.n_head = n_head
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.use_cache = use_cache
self.apply_residual_connection_post_layernorm = (
apply_residual_connection_post_layernorm)
self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.logn = logn
self.training_seqlen = training_seqlen
self.embed_layernorm = embed_layernorm
self.num_key_value_heads = kwargs.pop("num_key_value_heads", None)
self.ffn_hidden_size = ffn_hidden_size
self.hidden_act = hidden_act
super().__init__(bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
**kwargs)