Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Support telechat2 #10311

Merged
merged 27 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,11 @@ Text Generation
- :code:`upstage/solar-pro-preview-instruct`, etc.
- ✅︎
- ✅︎
* - :code:`TeleChat2ForCausalLM`
- TeleChat2
- :code:`TeleAI/TeleChat2-3B`, :code:`TeleAI/TeleChat2-7B`, :code:`TeleAI/TeleChat2-35B`, etc.
-
Isotr0py marked this conversation as resolved.
Show resolved Hide resolved
- ✅︎
* - :code:`XverseForCausalLM`
- XVERSE
- :code:`xverse/XVERSE-7B-Chat`, :code:`xverse/XVERSE-13B-Chat`, :code:`xverse/XVERSE-65B-Chat`, etc.
Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
"SolarForCausalLM": ("solar", "SolarForCausalLM"),
"TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
# [Encoder-decoder]
"BartModel": ("bart", "BartForConditionalGeneration"),
Expand Down
140 changes: 140 additions & 0 deletions vllm/model_executor/models/telechat2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable, Set, Tuple

import torch

from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaModel

from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix


class TeleChat2Model(LlamaModel):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# 1. Initialize the LlamaModel with bias
vllm_config.model_config.hf_config.bias = True
vllm_config.model_config.hf_config.mlp_bias = True
super().__init__(vllm_config=vllm_config, prefix=prefix)
# 2. Remove the bias from the qkv_proj and gate_up_proj based on config
# Telechat2's gate_up_proj and qkv_proj don't have bias
# see: https://github.com/vllm-project/vllm/pull/10311#issuecomment-2490297566
for layer in self.layers:
layer.self_attn.qkv_proj.bias = layer.mlp.gate_up_proj.bias = None
layer.self_attn.qkv_proj.skip_bias_add = True
layer.mlp.gate_up_proj.skip_bias_add = True

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
('gate_up_proj', 'gate_proj', 0),
('gate_up_proj', 'up_proj', 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
total_num_heads = self.config.n_head
head_dim = self.config.hidden_size // total_num_heads
for name, loaded_weight in weights:
if "self_attn.key_value" in name:
k_weight = []
v_weight = []
for i in range(total_num_heads):
start = i * head_dim * 2
k_weight.append(loaded_weight[start:start + head_dim, :])
v_weight.append(loaded_weight[start + head_dim:start +
2 * head_dim:])
k_weight = torch.cat(k_weight, dim=0)
v_weight = torch.cat(v_weight, dim=0)
name = name.replace("key_value", "qkv_proj")
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, k_weight, "k")
weight_loader(param, v_weight, "v")
elif "query" in name:
name = name.replace("query", "qkv_proj")
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, "q")
else:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params


class TeleChat2ForCausalLM(LlamaForCausalLM):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super(LlamaForCausalLM, self).__init__()
Isotr0py marked this conversation as resolved.
Show resolved Hide resolved
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
config.intermediate_size = config.ffn_hidden_size
config.hidden_act = "silu"
config.rms_norm_eps = config.layer_norm_epsilon
config.tie_word_embeddings = False
self.config = config
self.model = TeleChat2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))

self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"))
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:

hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"transformer.": "model.",
},
orig_to_new_substr={
".h.": ".layers.",
".self_attention.": ".self_attn.",
".word_embeddings.": ".embed_tokens.",
".dense.": ".o_proj.",
".ln_f.": ".norm.",
},
)
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights, mapper=hf_to_vllm_mapper)