diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 4c096acdf2035..3da99bcbee9ae 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -29,7 +29,8 @@ MLPSpeculatorConfig, MPTConfig, NemotronConfig, NVLM_D_Config, Olmo2Config, RWConfig, - SolarConfig, UltravoxConfig) + SolarConfig, Telechat2Config, + UltravoxConfig) # yapf: enable from vllm.transformers_utils.utils import check_gguf_file from vllm.utils import resolve_obj_by_qualname @@ -64,6 +65,7 @@ "NVLM_D": NVLM_D_Config, "olmo2": Olmo2Config, "solar": SolarConfig, + "telechat": Telechat2Config, "ultravox": UltravoxConfig, **_CONFIG_REGISTRY_OVERRIDE_HF } diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 4c721001d8434..c24433cd436b4 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -17,6 +17,7 @@ from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config from vllm.transformers_utils.configs.olmo2 import Olmo2Config from vllm.transformers_utils.configs.solar import SolarConfig +from vllm.transformers_utils.configs.telechat2 import Telechat2Config from vllm.transformers_utils.configs.ultravox import UltravoxConfig __all__ = [ @@ -36,5 +37,6 @@ "NVLM_D_Config", "Olmo2Config", "SolarConfig", + "Telechat2Config", "UltravoxConfig", ] \ No newline at end of file diff --git a/vllm/transformers_utils/configs/telechat2.py b/vllm/transformers_utils/configs/telechat2.py new file mode 100644 index 0000000000000..b292fe8281db7 --- /dev/null +++ b/vllm/transformers_utils/configs/telechat2.py @@ -0,0 +1,76 @@ +# adapted from https://www.modelscope.cn/models/TeleAI/TeleChat2-3B/resolve/master/configuration_telechat2.py +""" Telechat configuration compatible with LlamaConfig. """ + +from transformers.configuration_utils import PretrainedConfig + + +class Telechat2Config(PretrainedConfig): + """ + Args: + vocab_size (`int`, *optional*, defaults to 160256): Vocabulary size of the Telechat model. + hidden_size (`int`, *optional*, defaults to 4096): Dimensionality of the embeddings and hidden states. + ffn_hidden_size (`int`, *optional*, defaults to 12288): Dimensionality of the feed-forward hidden states. + n_layer (`int`, *optional*, defaults to 30): Number of hidden layers in the Transformer + n_head (`int`, *optional*, defaults to 32): Number of attention heads for each attention layer. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): The epsilon to use in the layer normalization layers. + initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + apply_residual_connection_post_layernorm (`bool`, *optional*, defaults to `False`): If enabled, use the layer norm of the hidden states as the residual in the transformer blocks + hidden_dropout (`float`, *optional*, defaults to 0.0): Dropout rate of the dropout function on the bias dropout. + attention_dropout (`float`, *optional*, defaults to 0.0): Dropout rate applied to the attention probs + use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions. + training_seqlen (`int`, *optional*, defaults to 8192): Sequence length during last finetuning. + logn (`bool`, *optional*, defaults to `True`): Whether or not to use logN during extrapolation. + embed_layernorm (`bool`, *optional*, defaults to `True`): Whether or not to use embedding layernorm. + + """ + + model_type = "telechat" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "num_hidden_layers": "n_layer", + "num_attention_heads": "n_head", + "intermediate_size": "ffn_hidden_size", + "rms_norm_eps": "layer_norm_epsilon" + } + + def __init__( + self, + vocab_size=160256, + hidden_size=4096, + n_layer=30, + n_head=32, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + use_cache=True, + bos_token_id=1, + eos_token_id=2, + apply_residual_connection_post_layernorm=False, + hidden_dropout=0.0, + attention_dropout=0.0, + ffn_hidden_size=12288, + training_seqlen = 8192, + logn = True, + embed_layernorm = False, + hidden_act="silu", + **kwargs, + ): + self.vocab_size = vocab_size + n_embed = kwargs.pop("n_embed", None) + self.hidden_size = hidden_size if n_embed is None else n_embed + self.n_layer = n_layer + self.n_head = n_head + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.use_cache = use_cache + self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.logn = logn + self.training_seqlen = training_seqlen + self.embed_layernorm = embed_layernorm + self.num_key_value_heads= kwargs.pop("num_key_value_heads", None) + self.ffn_hidden_size = ffn_hidden_size + self.hidden_act = hidden_act + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)