Commit 6712ef1

fix
xiangw2 committed Nov 26, 2024
1 parent b492509 commit 6712ef1
Showing 1 changed file with 17 additions and 118 deletions.
vllm/model_executor/models/telechat2.py (135 changes: 17 additions & 118 deletions)
@@ -17,133 +17,33 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Iterable, Optional, Set, Tuple
+from typing import Iterable, Set, Tuple
 
 import torch
-from transformers import PretrainedConfig
 
-from vllm.config import CacheConfig, VllmConfig
-from vllm.model_executor.layers.linear import RowParallelLinear
+from vllm.config import VllmConfig
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.llama import (LlamaAttention,
-                                              LlamaDecoderLayer,
-                                              LlamaForCausalLM, LlamaMLP,
-                                              LlamaModel)
+from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaModel
 
-from .utils import AutoWeightsLoader, WeightsMapper, make_layers, maybe_prefix
-
-
-class TeleChat2MLP(LlamaMLP):
-
-    def __init__(
-        self,
-        hidden_size: int,
-        intermediate_size: int,
-        hidden_act: str,
-        quant_config: Optional[QuantizationConfig] = None,
-        bias: bool = False,
-        prefix: str = "",
-    ) -> None:
-        super().__init__(hidden_size, intermediate_size, hidden_act,
-                         quant_config, bias, prefix)
-        self.down_proj = RowParallelLinear(
-            input_size=intermediate_size,
-            output_size=hidden_size,
-            bias=True,
-            quant_config=quant_config,
-        )
-
-
-class TeleChat2Attention(LlamaAttention):
-
-    def __init__(self,
-                 config,
-                 hidden_size: int,
-                 num_heads: int,
-                 num_kv_heads: int,
-                 rope_theta: float = 10000,
-                 rope_scaling: Optional[Dict[str, Any]] = None,
-                 max_position_embeddings: int = 8192,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 bias: bool = False,
-                 cache_config: Optional[CacheConfig] = None,
-                 prefix: str = "") -> None:
-        super().__init__(config, hidden_size, num_heads, num_kv_heads,
-                         rope_theta, rope_scaling, max_position_embeddings,
-                         quant_config, bias, cache_config, prefix)
-        self.o_proj = RowParallelLinear(
-            input_size=hidden_size,
-            output_size=hidden_size,
-            bias=True,
-            quant_config=quant_config,
-            input_is_parallel=True,
-            prefix=f"{prefix}.dense_proj",
-        )
-
-
-class TeleChat2DecoderLayer(LlamaDecoderLayer):
-
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
-        super().__init__(config, cache_config, quant_config, prefix)
-        self.hidden_size = config.hidden_size
-        rope_theta = getattr(config, "rope_theta", 10000)
-        rope_scaling = getattr(config, "rope_scaling", None)
-        if rope_scaling is not None and getattr(
-                config, "original_max_position_embeddings", None):
-            rope_scaling["original_max_position_embeddings"] = (
-                config.original_max_position_embeddings)
-        max_position_embeddings = getattr(config, "max_position_embeddings",
-                                          8192)
-        attention_bias = getattr(config, "attention_bias", False) or getattr(
-            config, "bias", False)
-        self.self_attn = TeleChat2Attention(
-            config,
-            hidden_size=self.hidden_size,
-            num_heads=config.num_attention_heads,
-            num_kv_heads=getattr(config, "num_key_value_heads",
-                                 config.num_attention_heads),
-            rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
-            max_position_embeddings=max_position_embeddings,
-            quant_config=quant_config,
-            bias=attention_bias,
-            cache_config=cache_config,
-        )
-        self.mlp = TeleChat2MLP(
-            hidden_size=self.hidden_size,
-            intermediate_size=config.intermediate_size,
-            hidden_act=config.hidden_act,
-            quant_config=quant_config,
-            bias=getattr(config, "mlp_bias", False),
-        )
+from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
 
 
 class TeleChat2Model(LlamaModel):
 
-    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        # 1. Initialize the LlamaModel with bias
+        vllm_config.model_config.hf_config.bias = True
+        vllm_config.model_config.hf_config.mlp_bias = True
         super().__init__(vllm_config=vllm_config, prefix=prefix)
-        config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
-        quant_config = vllm_config.quant_config
-        self.start_layer, self.end_layer, self.layers = make_layers(
-            config.num_hidden_layers,
-            lambda prefix: TeleChat2DecoderLayer(config=config,
-                                                 cache_config=cache_config,
-                                                 quant_config=quant_config,
-                                                 prefix=f"{prefix}.layers"),
-            prefix=f"{prefix}.layers",
-        )
+        # 2. Remove the bias from the qkv_proj and gate_up_proj based on config
+        # FIXME: Handle qkv_bias etc
+        for layer in self.layers:
+            layer.self_attn.qkv_proj.bias = layer.mlp.gate_up_proj.bias = None
+            layer.self_attn.qkv_proj.skip_bias_add = True
+            layer.mlp.gate_up_proj.skip_bias_add = True
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
@@ -171,13 +71,11 @@ def load_weights(self, weights: Iterable[Tuple[str,
                 weight_loader = param.weight_loader
                 weight_loader(param, k_weight, "k")
                 weight_loader(param, v_weight, "v")
-                loaded_params.add(name)
             elif "query" in name:
                 name = name.replace("query", "qkv_proj")
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, "q")
-                loaded_params.add(name)
             else:
                 for param_name, weight_name, shard_id in stacked_params_mapping:
                     if weight_name not in name:
@@ -192,6 +90,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
                     param = params_dict[name]
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)
                     weight_loader(param, loaded_weight)
+            loaded_params.add(name)
         return loaded_params


Expand All @@ -207,7 +106,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config.tie_word_embeddings = False
self.config = config
self.model = TeleChat2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "transformer"))
prefix=maybe_prefix(prefix, "model"))

self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
@@ -237,4 +136,4 @@ def load_weights(self, weights: Iterable[Tuple[str,
             skip_prefixes=(["lm_head."]
                            if self.config.tie_word_embeddings else None),
         )
-        return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
\ No newline at end of file
+        return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
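Note on the bias trick retained above: TeleChat2 uses bias only on the attention output and MLP down projections, while LlamaDecoderLayer applies the config's bias flags uniformly to every linear layer. The new __init__ therefore builds LlamaModel with bias and mlp_bias forced on, then strips the biases TeleChat2 does not actually have. A minimal sketch of the same idea using plain torch.nn (the module names here are illustrative stand-ins, not vLLM's classes):

import torch
import torch.nn as nn

# Stand-in for one decoder layer built uniformly with bias=True, the way
# LlamaDecoderLayer does once hf_config.bias / mlp_bias are forced on.
layer = nn.ModuleDict({
    "qkv_proj": nn.Linear(8, 24, bias=True),  # TeleChat2 has no bias here
    "o_proj": nn.Linear(8, 8, bias=True),     # TeleChat2 keeps this bias
})

# Mirrors `layer.self_attn.qkv_proj.bias = None` in the diff: assigning
# None to a registered parameter removes it cleanly from the module.
layer["qkv_proj"].bias = None

x = torch.randn(2, 8)
assert layer["qkv_proj"].bias is None      # unwanted bias removed
assert layer["o_proj"].bias is not None    # wanted bias kept
_ = layer["qkv_proj"](x)                   # forward still works bias-free

vLLM's parallel linear layers additionally get skip_bias_add = True in the diff because they return the bias separately from their forward output; plain nn.Linear has no such flag, so the sketch only shows the bias removal.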
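The load_weights override kept by this commit exists because TeleChat2 checkpoints fuse K and V per head: each "key_value" tensor stacks [K_head_i; V_head_i] blocks, which must be de-interleaved into one contiguous K block and one contiguous V block before vLLM's fused qkv_proj weight loader can shard them (the "k"/"v" calls in the hunk above). A self-contained sketch of that slicing, with an assumed helper name and toy shapes:

from typing import Tuple

import torch

def split_key_value(loaded_weight: torch.Tensor, total_num_heads: int,
                    head_dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """De-interleave per-head [K; V] blocks into contiguous K and V."""
    k_weight, v_weight = [], []
    for i in range(total_num_heads):
        start = i * head_dim * 2
        k_weight.append(loaded_weight[start:start + head_dim, :])
        v_weight.append(loaded_weight[start + head_dim:start + 2 * head_dim, :])
    return torch.cat(k_weight, dim=0), torch.cat(v_weight, dim=0)

# Toy check: 2 heads, head_dim=1, hidden size 3; rows are K0, V0, K1, V1.
w = torch.tensor([[1., 1., 1.],
                  [2., 2., 2.],
                  [3., 3., 3.],
                  [4., 4., 4.]])
k, v = split_key_value(w, total_num_heads=2, head_dim=1)
assert torch.equal(k, torch.tensor([[1., 1., 1.], [3., 3., 3.]]))
assert torch.equal(v, torch.tensor([[2., 2., 2.], [4., 4., 4.]]))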
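Finally, the prefix change from "transformer" to "model" in the second-to-last hunk lines the submodule name up with the hf_to_vllm_mapper passed to load_weights, which presumably remaps HF checkpoint names under "transformer." onto vLLM's "model." hierarchy. A toy stand-in for that remapping (the real WeightsMapper imported from .utils does more; the checkpoint name below is illustrative):

from typing import Iterable, Iterator, Tuple

import torch

def remap_prefix(weights: Iterable[Tuple[str, torch.Tensor]],
                 orig: str = "transformer.",
                 new: str = "model.") -> Iterator[Tuple[str, torch.Tensor]]:
    # Rewrite checkpoint parameter names so they match the vLLM module tree.
    for name, tensor in weights:
        if name.startswith(orig):
            name = new + name[len(orig):]
        yield name, tensor

ckpt = [("transformer.h.0.self_attention.query.weight", torch.zeros(2, 2))]
remapped = dict(remap_prefix(ckpt))
assert "model.h.0.self_attention.query.weight" in remapped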
