diff --git a/vllm/model_executor/models/telechat2.py b/vllm/model_executor/models/telechat2.py index 0b03d5542d1e9..ff536115f7b1e 100644 --- a/vllm/model_executor/models/telechat2.py +++ b/vllm/model_executor/models/telechat2.py @@ -17,133 +17,33 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Iterable, Optional, Set, Tuple +from typing import Iterable, Set, Tuple import torch -from transformers import PretrainedConfig -from vllm.config import CacheConfig, VllmConfig -from vllm.model_executor.layers.linear import RowParallelLinear +from vllm.config import VllmConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.llama import (LlamaAttention, - LlamaDecoderLayer, - LlamaForCausalLM, LlamaMLP, - LlamaModel) +from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaModel -from .utils import AutoWeightsLoader, WeightsMapper, make_layers, maybe_prefix - - -class TeleChat2MLP(LlamaMLP): - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, - bias: bool = False, - prefix: str = "", - ) -> None: - super().__init__(hidden_size, intermediate_size, hidden_act, - quant_config, bias, prefix) - self.down_proj = RowParallelLinear( - input_size=intermediate_size, - output_size=hidden_size, - bias=True, - quant_config=quant_config, - ) - - -class TeleChat2Attention(LlamaAttention): - - def __init__(self, - config, - hidden_size: int, - num_heads: int, - num_kv_heads: 
int, - rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, - max_position_embeddings: int = 8192, - quant_config: Optional[QuantizationConfig] = None, - bias: bool = False, - cache_config: Optional[CacheConfig] = None, - prefix: str = "") -> None: - super().__init__(config, hidden_size, num_heads, num_kv_heads, - rope_theta, rope_scaling, max_position_embeddings, - quant_config, bias, cache_config, prefix) - self.o_proj = RowParallelLinear( - input_size=hidden_size, - output_size=hidden_size, - bias=True, - quant_config=quant_config, - input_is_parallel=True, - prefix=f"{prefix}.dense_proj", - ) - - -class TeleChat2DecoderLayer(LlamaDecoderLayer): - - def __init__( - self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__(config, cache_config, quant_config, prefix) - self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) - if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): - rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) - attention_bias = getattr(config, "attention_bias", False) or getattr( - config, "bias", False) - self.self_attn = TeleChat2Attention( - config, - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - num_kv_heads=getattr(config, "num_key_value_heads", - config.num_attention_heads), - rope_theta=rope_theta, - rope_scaling=rope_scaling, - max_position_embeddings=max_position_embeddings, - quant_config=quant_config, - bias=attention_bias, - cache_config=cache_config, - ) - self.mlp = TeleChat2MLP( - hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - 
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the Llama layer stack and adapt its bias layout for TeleChat2.

    The Llama layers are constructed with ``bias`` and ``mlp_bias`` forced
    on in the HF config, after which the bias of the fused ``qkv_proj``
    and ``gate_up_proj`` projections is removed again, leaving a bias only
    on the projections that are not touched here.

    Note that the two flags are written onto the ``hf_config`` object held
    by ``vllm_config`` itself, so the change is visible to every other
    user of that config object.

    Args:
        vllm_config: Engine configuration; its ``model_config.hf_config``
            is mutated in place (``bias`` and ``mlp_bias`` set to True).
        prefix: Module-name prefix forwarded to ``LlamaModel``.
    """
    # 1. Initialize the LlamaModel with bias enabled everywhere.
    hf_config = vllm_config.model_config.hf_config
    hf_config.bias = True
    hf_config.mlp_bias = True
    super().__init__(vllm_config=vllm_config, prefix=prefix)

    # 2. Drop the bias from the fused qkv_proj / gate_up_proj projections.
    # FIXME: the bias is stripped unconditionally; honour config switches
    # such as qkv_bias once they need to be supported.
    #
    # Only visit the layers owned by this pipeline-parallel rank: slots
    # outside [start_layer, end_layer) are placeholders that have neither
    # `self_attn` nor `mlp`, so touching them would raise AttributeError.
    for layer in self.layers[self.start_layer:self.end_layer]:
        for fused_proj in (layer.self_attn.qkv_proj,
                           layer.mlp.gate_up_proj):
            fused_proj.bias = None
            # With no bias left, the layer must not try to add one.
            fused_proj.skip_bias_add = True
loaded_params @@ -207,7 +106,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.tie_word_embeddings = False self.config = config self.model = TeleChat2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "transformer")) + prefix=maybe_prefix(prefix, "model")) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, @@ -237,4 +136,4 @@ def load_weights(self, weights: Iterable[Tuple[str, skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) - return loader.load_weights(weights, mapper=hf_to_vllm_mapper) + return loader.load_weights(weights, mapper=hf_to_vllm_mapper) \ No newline at end of file