diff --git a/tests/models/registry.py b/tests/models/registry.py
index 865e90b3f8b0e..201a6b0294917 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -115,6 +115,9 @@ class _HfExamplesInfo:
     "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
     "Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
     "SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"),
+    "TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-115B",
+                                            is_available_online=False,
+                                            trust_remote_code=True),
     "XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat",
                                          is_available_online=False,
                                          trust_remote_code=True),
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index b28210e26629a..4462f6ed55a9c 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -119,6 +119,7 @@
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
     "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
     "Qwen2ForSequenceClassification": ("qwen2_cls", "Qwen2ForSequenceClassification"),  # noqa: E501
+    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
     # [Multimodal]
     "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
     "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
diff --git a/vllm/model_executor/models/telechat2.py b/vllm/model_executor/models/telechat2.py
index 0ea79a1717712..205e57c6b46b8 100644
--- a/vllm/model_executor/models/telechat2.py
+++ b/vllm/model_executor/models/telechat2.py
@@ -23,6 +23,7 @@
 
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.pooler import Pooler, PoolingType
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -100,6 +101,7 @@ class TeleChat2ForCausalLM(LlamaForCausalLM):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super(LlamaForCausalLM, self).__init__()
         config = vllm_config.model_config.hf_config
+        pooler_config = vllm_config.model_config.pooler_config
         quant_config = vllm_config.quant_config
         config.intermediate_size = config.ffn_hidden_size
         config.hidden_act = "silu"
@@ -116,6 +118,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                                       prefix=maybe_prefix(prefix, "lm_head"))
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
+        self._pooler = Pooler.from_config_with_defaults(
+            pooler_config,
+            pooling_type=PoolingType.STEP,
+            normalize=False,
+            softmax=False)
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
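
A note on the registry change: the new entry maps the HF architecture string "TeleChat2ForCausalLM" to a (module, class) pair that vLLM resolves lazily. The sketch below is a simplified, hypothetical rendering of that lookup, not the registry's actual implementation; the real code in vllm/model_executor/models/registry.py adds lazy registration and import-error handling on top of this idea.

    # Simplified sketch of what the ("telechat2", "TeleChat2ForCausalLM")
    # tuple means: module name under vllm.model_executor.models, then the
    # class to pull off that module.
    from importlib import import_module

    def resolve_architecture(module_name: str, class_name: str):
        # e.g. resolve_architecture("telechat2", "TeleChat2ForCausalLM")
        module = import_module(f"vllm.model_executor.models.{module_name}")
        return getattr(module, class_name)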
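
The super(LlamaForCausalLM, self).__init__() call visible in the telechat2.py context lines is the "skip one level" pattern: it starts the MRO lookup after LlamaForCausalLM, so LlamaForCausalLM.__init__ is bypassed and the base nn.Module initializer runs instead. That lets TeleChat2 translate its HF config fields (ffn_hidden_size, layer_norm_epsilon) into the Llama-style names before building the submodules itself. A standalone toy example with hypothetical class names:

    class Base:
        def __init__(self):
            print("Base.__init__ runs")

    class Llama(Base):
        def __init__(self):
            print("Llama.__init__ runs")  # what the pattern deliberately skips

    class TeleChat2(Llama):
        def __init__(self):
            # super(Llama, self) starts the MRO search *after* Llama,
            # so Base.__init__ runs and Llama.__init__ is bypassed.
            super(Llama, self).__init__()

    TeleChat2()  # prints "Base.__init__ runs" only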
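
On the new _pooler: Pooler.from_config_with_defaults builds a pooler whose defaults (here PoolingType.STEP with no normalization or softmax, the step-wise style used elsewhere in vLLM for reward-model pooling) can be overridden by a user-supplied pooler_config. Comparable models that hold a _pooler, such as Qwen2ForRewardModel in qwen2_rm.py, also expose it through a pooler method so the engine can invoke it on pooling runs. The diff does not itself add that method, so the following is a hedged sketch modeled on those models, with the imports they use:

    # Sketch only - mirrors the pooler method pattern from qwen2_rm.py;
    # whether TeleChat2 should expose it this way is an assumption.
    from typing import Optional

    import torch

    from vllm.model_executor.pooling_metadata import PoolingMetadata
    from vllm.sequence import PoolerOutput

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        return self._pooler(hidden_states, pooling_metadata)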
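
End to end, the registration makes the architecture loadable through vLLM's normal entry points. A usage sketch: the 115B checkpoint is marked is_available_online=False in the test registry, so the model path below is illustrative, and trust_remote_code=True is needed because the TeleChat2 repos ship custom HF config code.

    from vllm import LLM, SamplingParams

    # Illustrative checkpoint path; substitute a locally available copy.
    llm = LLM(model="Tele-AI/TeleChat2-115B", trust_remote_code=True)

    params = SamplingParams(temperature=0.8, max_tokens=64)
    outputs = llm.generate(["Introduce yourself briefly."], params)
    print(outputs[0].outputs[0].text)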