Commit e5757a2: cleanup
Signed-off-by: Isotr0py <[email protected]>
Isotr0py committed Nov 27, 2024
1 parent 1ea77b5 commit e5757a2
Showing 3 changed files with 9 additions and 34 deletions.
docs/source/models/supported_models.rst (1 addition, 1 deletion)
@@ -312,7 +312,7 @@ Text Generation
   * - :code:`TeleChat2ForCausalLM`
     - TeleChat2
     - :code:`TeleAI/TeleChat2-3B`, :code:`TeleAI/TeleChat2-7B`, :code:`TeleAI/TeleChat2-35B`, etc.
-    -
+    - ✅︎
     - ✅︎
   * - :code:`XverseForCausalLM`
     - XVERSE
vllm/model_executor/models/llama.py (4 additions, 2 deletions)
@@ -510,8 +510,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.config = config
         self.lora_config = lora_config
 
-        self.model = LlamaModel(vllm_config=vllm_config,
-                                prefix=maybe_prefix(prefix, "model"))
+        self.model = self._init_model(vllm_config=vllm_config, prefix=prefix)
         if get_pp_group().is_last_rank:
             self.unpadded_vocab_size = config.vocab_size
             if lora_config:
@@ -548,6 +547,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             normalize=False,
             softmax=False)
 
+    def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
+        return LlamaModel(vllm_config=vllm_config, prefix=prefix)
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
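Net effect of the llama.py change: LlamaForCausalLM.__init__ no longer constructs LlamaModel directly; it calls the new _init_model hook, which subclasses can override to supply a different backbone while inheriting the rest of the initialization. A minimal, self-contained sketch of this template-method pattern follows, using hypothetical Backbone/CausalLM stand-ins rather than the real vLLM classes:

    # Sketch of the _init_model hook pattern (hypothetical classes,
    # not the actual vLLM implementations).

    class Backbone:
        def __init__(self, prefix: str = "") -> None:
            self.prefix = prefix

    class CausalLM:
        def __init__(self, prefix: str = "") -> None:
            # Backbone construction is delegated to a hook, so subclasses
            # can customize it without rewriting __init__.
            self.model = self._init_model(prefix=prefix)

        def _init_model(self, prefix: str = "") -> Backbone:
            return Backbone(prefix=prefix)

    class CustomBackbone(Backbone):
        pass

    class CustomCausalLM(CausalLM):
        # Override only the hook; the rest of __init__ is inherited.
        def _init_model(self, prefix: str = "") -> Backbone:
            return CustomBackbone(prefix=prefix)

    assert isinstance(CustomCausalLM().model, CustomBackbone)

This is exactly what telechat2.py below does: TeleChat2ForCausalLM overrides _init_model to return TeleChat2Model and drops its duplicated __init__.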
vllm/model_executor/models/telechat2.py (4 additions, 31 deletions)
@@ -22,14 +22,10 @@
 import torch
 
 from vllm.config import VllmConfig
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
-from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaModel
 
-from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
+from .utils import AutoWeightsLoader, WeightsMapper
 
 
 class TeleChat2Model(LlamaModel):
@@ -98,31 +94,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
 
 class TeleChat2ForCausalLM(LlamaForCausalLM):
 
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super(LlamaForCausalLM, self).__init__()
-        config = vllm_config.model_config.hf_config
-        pooler_config = vllm_config.model_config.pooler_config
-        quant_config = vllm_config.quant_config
-        config.intermediate_size = config.ffn_hidden_size
-        config.hidden_act = "silu"
-        config.rms_norm_eps = config.layer_norm_epsilon
-        config.tie_word_embeddings = False
-        self.config = config
-        self.model = TeleChat2Model(vllm_config=vllm_config,
-                                    prefix=maybe_prefix(prefix, "model"))
-
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.hidden_size,
-                                      bias=False,
-                                      quant_config=quant_config,
-                                      prefix=maybe_prefix(prefix, "lm_head"))
-        self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.sampler = Sampler()
-        self._pooler = Pooler.from_config_with_defaults(
-            pooler_config,
-            pooling_type=PoolingType.STEP,
-            normalize=False,
-            softmax=False)
+    def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
+        return TeleChat2Model(vllm_config=vllm_config, prefix=prefix)
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
@@ -144,4 +117,4 @@ def load_weights(self, weights: Iterable[Tuple[str,
             skip_prefixes=(["lm_head."]
                            if self.config.tie_word_embeddings else None),
         )
-        return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
\ No newline at end of file
+        return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
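load_weights still delegates to AutoWeightsLoader with hf_to_vllm_mapper, a WeightsMapper that renames TeleChat2 checkpoint keys into the Llama layout before loading. As a rough illustration of that renaming step only, here is a self-contained sketch using a plain prefix-substitution table (the prefix entries are hypothetical examples, not the mapper's actual contents):

    from typing import Dict, Iterable, Tuple

    import torch

    # Hypothetical table: checkpoint-key prefix -> vLLM parameter prefix.
    _PREFIX_MAP: Dict[str, str] = {
        "transformer.": "model.",
    }

    def remap(weights: Iterable[Tuple[str, torch.Tensor]]
              ) -> Iterable[Tuple[str, torch.Tensor]]:
        """Yield (name, tensor) pairs with prefixes renamed; others pass through."""
        for name, tensor in weights:
            for old, new in _PREFIX_MAP.items():
                if name.startswith(old):
                    name = new + name[len(old):]
                    break
            yield name, tensor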
