From f89d18ff74e48f97c76afbab31956218d2486e36 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Sun, 10 Nov 2024 22:41:46 -0800
Subject: [PATCH] [6/N] pass whole config to inner model (#10205)

Signed-off-by: youkaichao
---
 vllm/model_executor/models/arctic.py | 23 ++++----
 vllm/model_executor/models/baichuan.py | 48 ++++++++--------
 vllm/model_executor/models/bart.py | 23 ++++----
 vllm/model_executor/models/bert.py | 25 ++++-----
 vllm/model_executor/models/blip2.py | 10 +---
 vllm/model_executor/models/bloom.py | 21 +++----
 vllm/model_executor/models/chameleon.py | 27 ++++-----
 vllm/model_executor/models/chatglm.py | 19 ++++---
 vllm/model_executor/models/commandr.py | 33 +++++------
 vllm/model_executor/models/dbrx.py | 21 +++----
 vllm/model_executor/models/decilm.py | 6 +-
 vllm/model_executor/models/deepseek.py | 26 ++++-----
 vllm/model_executor/models/deepseek_v2.py | 29 ++++------
 vllm/model_executor/models/eagle.py | 5 +-
 vllm/model_executor/models/exaone.py | 34 +++++------
 vllm/model_executor/models/falcon.py | 21 +++----
 vllm/model_executor/models/florence2.py | 38 ++++++-------
 vllm/model_executor/models/gemma.py | 24 +++-----
 vllm/model_executor/models/gemma2.py | 28 +++-------
 vllm/model_executor/models/gpt2.py | 24 ++++----
 vllm/model_executor/models/gpt_bigcode.py | 22 ++++----
 vllm/model_executor/models/gpt_j.py | 21 +++----
 vllm/model_executor/models/gpt_neox.py | 20 +++----
 vllm/model_executor/models/granite.py | 34 +++++------
 vllm/model_executor/models/granitemoe.py | 33 ++++-------
 vllm/model_executor/models/idefics3.py | 29 ++++------
 vllm/model_executor/models/internlm2.py | 24 +++-----
 vllm/model_executor/models/internlm2_ve.py | 23 +++-----
 vllm/model_executor/models/internvl.py | 6 +-
 vllm/model_executor/models/jais.py | 21 +++----
 vllm/model_executor/models/jamba.py | 30 ++++------
 vllm/model_executor/models/llama.py | 44 ++++-----------
 vllm/model_executor/models/llava.py | 6 +-
 vllm/model_executor/models/llava_next.py | 6 +-
 .../model_executor/models/llava_next_video.py | 6 +-
 vllm/model_executor/models/llava_onevision.py | 6 +-
 vllm/model_executor/models/mamba.py | 31 +++++-----
 vllm/model_executor/models/minicpm.py | 38 ++++++-------
 vllm/model_executor/models/minicpm3.py | 12 ++--
 vllm/model_executor/models/minicpmv.py | 56 +++++++------------
 vllm/model_executor/models/mixtral.py | 34 +++++------
 vllm/model_executor/models/mixtral_quant.py | 26 ++++-----
 vllm/model_executor/models/mllama.py | 49 +++++++---------
 vllm/model_executor/models/molmo.py | 26 ++++-----
 vllm/model_executor/models/mpt.py | 26 ++++-----
 vllm/model_executor/models/nemotron.py | 34 +++++------
 vllm/model_executor/models/olmo.py | 24 ++++----
 vllm/model_executor/models/olmoe.py | 26 ++++-----
 vllm/model_executor/models/opt.py | 24 +++-----
 vllm/model_executor/models/orion.py | 26 ++++-----
 vllm/model_executor/models/paligemma.py | 13 ++---
 vllm/model_executor/models/persimmon.py | 27 ++++-----
 vllm/model_executor/models/phi.py | 24 ++++----
 vllm/model_executor/models/phi3_small.py | 26 ++++-----
 vllm/model_executor/models/phi3v.py | 14 ++---
 vllm/model_executor/models/phimoe.py | 34 +++++------
 vllm/model_executor/models/pixtral.py | 10 +---
 vllm/model_executor/models/qwen.py | 27 ++++-----
 vllm/model_executor/models/qwen2.py | 23 +++-----
 vllm/model_executor/models/qwen2_audio.py | 12 ++--
 vllm/model_executor/models/qwen2_cls.py | 11 ++--
 vllm/model_executor/models/qwen2_moe.py | 26 ++++-----
 vllm/model_executor/models/qwen2_rm.py | 11 ++--
 vllm/model_executor/models/qwen2_vl.py | 19 +++---
 vllm/model_executor/models/solar.py | 34 +++++------
 vllm/model_executor/models/stablelm.py | 24 ++++----
 vllm/model_executor/models/starcoder2.py | 26 ++++-----
 vllm/model_executor/models/ultravox.py | 16 +++---
 vllm/model_executor/models/xverse.py | 19 ++-----
 69 files changed, 681 insertions(+), 963 deletions(-)

diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py
index 997554f7dcccd..7d4b9654b54ab 100644
--- a/vllm/model_executor/models/arctic.py
+++ b/vllm/model_executor/models/arctic.py
@@ -34,7 +34,8 @@
 
 from .interfaces import SupportsPP
 from .utils import (is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory, make_layers)
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)
 
 logger = init_logger(__name__)
 
@@ -364,14 +365,13 @@ def forward(
 @support_torch_compile
 class ArcticModel(nn.Module):
 
-    def __init__(
-        self,
-        config: ArcticConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
@@ -418,13 +418,10 @@ class ArcticForCausalLM(nn.Module, SupportsPP):
     def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
         self.config = config
-        self.model = ArcticModel(config,
-                                 cache_config,
-                                 quant_config,
-                                 prefix=prefix)
+        self.model = ArcticModel(vllm_config=vllm_config,
+                                 prefix=maybe_prefix(prefix, "model"))
         self.vocab_size = config.vocab_size
         self.lm_head = ParallelLMHead(
             self.vocab_size,
diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py
index 8e1dab71b1f39..aabbd31192a40 100644
--- a/vllm/model_executor/models/baichuan.py
+++ b/vllm/model_executor/models/baichuan.py
@@ -253,13 +253,18 @@ def forward(
 @support_torch_compile
 class BaiChuanModel(nn.Module):
 
-    def __init__(self,
-                 config: PretrainedConfig,
-                 position_embedding: str,
-                 cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        position_embedding: str = "ROPE",
+    ) -> None:
         super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+
         self.config = config
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
@@ -332,21 +337,22 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 
     def __init__(
         self,
+        *,
         vllm_config: VllmConfig,
         prefix: str = "",
         position_embedding: str = "ROPE",
     ):
         super().__init__()
         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
         self.config = config
         self.lora_config = lora_config
 
         self.quant_config = quant_config
-        self.model = BaiChuanModel(config, position_embedding, cache_config,
-                                   quant_config)
+        self.model = BaiChuanModel(vllm_config=vllm_config,
+                                   prefix=prefix,
+                                   position_embedding=position_embedding)
         self.lm_head =
ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) @@ -438,16 +444,16 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM): NOTE: the class name has a lower case 'c'. """ - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config if config.hidden_size == 4096: # baichuan2 7b - super().__init__(vllm_config, prefix, "ROPE") + super().__init__(vllm_config=vllm_config, + prefix=prefix, + position_embedding="ROPE") else: # baichuan 13b, baichuan2 13b - super().__init__(vllm_config, prefix, "ALIBI") + super().__init__(vllm_config=vllm_config, + prefix=prefix, + position_embedding="ALIBI") class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): @@ -455,9 +461,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): NOTE: the class name has an upper case 'C'. """ - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ): - super().__init__(vllm_config, prefix, "ROPE") + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, + prefix=prefix, + position_embedding="ROPE") diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index c6da6a590cf5a..a50a5a5b018e1 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -41,6 +41,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .utils import maybe_prefix + logger = logging.get_logger(__name__) @@ -739,13 +741,14 @@ class BartModel(nn.Module): "encoder.embed_tokens.weight", "decoder.embed_tokens.weight" ] - def __init__(self, - config: BartConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config self.padding_idx = config.pad_token_id @@ -810,20 +813,16 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, class BartForConditionalGeneration(nn.Module): base_model_prefix = "model" - def __init__(self, vllm_config: VllmConfig, prefix: str = ""): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config # currently all existing BART models have `tie_word_embeddings` enabled assert config.tie_word_embeddings self.config = config - self.model = BartModel(config, - cache_config, - quant_config, - lora_config=lora_config) + self.model = BartModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) self.unpadded_vocab_size = config.vocab_size if lora_config: diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 2b0f45c5603f5..614d2db8ccff6 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -21,6 +21,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors, PoolerOutput +from .utils import maybe_prefix + class BertEmbedding(nn.Module): @@ -309,12 +311,13 @@ def forward(self, 
hidden_states: torch.Tensor, class BertModel(nn.Module): - def __init__(self, - config: BertConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.embeddings = BertEmbedding(config) self.encoder = BertEncoder(config, cache_config, @@ -382,17 +385,11 @@ class BertEmbeddingModel(nn.Module): _pooler: An instance of Pooler used for pooling operations. """ - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config pooler_config = vllm_config.model_config.pooler_config - self.model = BertModel(config, cache_config, quant_config) + self.model = BertModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) self._pooler = Pooler.from_config_with_defaults( pooler_config, pooling_type=PoolingType.CLS, diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index cdc30eda2ab3c..03dc1d15ab697 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -23,7 +23,7 @@ get_max_blip_image_tokens) from .interfaces import SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, init_vllm_registered_model, - merge_multimodal_embeddings) + maybe_prefix, merge_multimodal_embeddings) # We use this internally as placeholders since there is no image token # defined on the HuggingFace repo @@ -483,11 +483,7 @@ def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs): @INPUT_REGISTRY.register_input_processor(input_processor_for_blip2) class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -517,7 +513,7 @@ def __init__( self.language_model = init_vllm_registered_model( config.text_config, vllm_config=vllm_config, - prefix="language_model") + prefix=maybe_prefix(prefix, "language_model")) self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 7540bc23efd88..2c14519fb9e0e 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -42,7 +42,8 @@ from .interfaces import SupportsPP from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: @@ -221,14 +222,13 @@ def forward( @support_torch_compile class BloomModel(nn.Module): - def __init__( - self, - config: BloomConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + 
self.embed_dim = config.hidden_size # Embedding + LN Embedding @@ -288,11 +288,12 @@ def __init__( ): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.transformer = BloomModel(config, cache_config, quant_config) + self.transformer = BloomModel(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "transformer")) if self.config.tie_word_embeddings: self.lm_head = self.transformer.word_embeddings else: diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index f79bad6190708..7b59c818e0b60 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -37,7 +37,8 @@ from .interfaces import SupportsMultiModal, SupportsPP from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) # These configs are not part of the model config but the preprocessor # and processor files, so we hardcode them in the model file for now. @@ -831,14 +832,13 @@ def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor: class ChameleonModel(nn.Module): - def __init__( - self, - config: ChameleonConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -924,19 +924,14 @@ def forward( class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config - self.model = ChameleonModel(config, cache_config, quant_config) + self.model = ChameleonModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) self.unpadded_vocab_size = config.vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index c14f2fcb15063..08ed84aa9c71a 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -39,7 +39,8 @@ from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) logger = init_logger(__name__) @@ -481,14 +482,13 @@ def forward( class ChatGLMModel(nn.Module): - def __init__( - self, - config: ChatGLMConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - ): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = 
vllm_config.quant_config + self.config = config self.embedding = VocabParallelEmbedding(config.padded_vocab_size, @@ -600,7 +600,6 @@ def __init__( ): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config multimodal_config = vllm_config.model_config.multimodal_config @@ -611,7 +610,9 @@ def __init__( self.quant_config = quant_config self.max_position_embeddings = getattr(config, "max_sequence_length", 8192) - self.transformer = ChatGLMModel(config, cache_config, quant_config) + self.transformer = ChatGLMModel(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "transformer")) if self.config.tie_word_embeddings: self.transformer.output_layer.weight = ( self.transformer.embedding.weight) diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index e921fa50b099e..cd5c1d6844716 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -28,7 +28,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, LoRAConfig, VllmConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -49,7 +49,8 @@ from .interfaces import SupportsLoRA, SupportsPP from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) @torch.compile @@ -253,15 +254,14 @@ def forward( @support_torch_compile class CohereModel(nn.Module): - def __init__( - self, - config: CohereConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - prefix: str = "", - ): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config lora_vocab = (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) if lora_config else 0 @@ -332,14 +332,9 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP): embedding_modules = {"embed_tokens": "input_embeddings"} embedding_padding_modules = [] - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config self.config = config @@ -353,10 +348,8 @@ def __init__( self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, scale=config.logit_scale) - self.model = CohereModel(config, - cache_config, - quant_config, - lora_config=lora_config) + self.model = CohereModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index e3b3164cacde3..d5f9b903183d4 
100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -25,7 +25,8 @@ from .interfaces import SupportsPP from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) class DbrxRouter(nn.Module): @@ -294,14 +295,13 @@ def forward( class DbrxModel(nn.Module): - def __init__( - self, - config: DbrxConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.wte = VocabParallelEmbedding( config.vocab_size, config.d_model, @@ -357,7 +357,6 @@ def __init__( ): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config self.config = config if config.tie_word_embeddings: @@ -365,7 +364,9 @@ def __init__( "tie_word_embeddings is not supported for Dbrx models.") self.quant_config = quant_config self.unpadded_vocab_size = config.vocab_size - self.transformer = DbrxModel(config, cache_config, quant_config) + self.transformer = DbrxModel(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "transformer")) self.lm_head = ParallelLMHead( config.vocab_size, config.d_model, diff --git a/vllm/model_executor/models/decilm.py b/vllm/model_executor/models/decilm.py index 3e7005efb39ca..b38fd9fa49c21 100644 --- a/vllm/model_executor/models/decilm.py +++ b/vllm/model_executor/models/decilm.py @@ -51,11 +51,7 @@ class DeciLMForCausalLM(LlamaForCausalLM): instead. 
""" - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config config.num_key_value_heads = max(config.num_key_value_heads_per_layer) delattr(config, "num_key_value_heads_per_layer") diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index c90d3d250e4c5..a9bf1440c4d60 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -50,7 +50,8 @@ from .interfaces import SupportsPP from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) class DeepseekMLP(nn.Module): @@ -326,14 +327,13 @@ class DeepseekModel(nn.Module): fall_back_to_pt_during_load = False - def __init__( - self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -383,18 +383,14 @@ def forward( class DeepseekForCausalLM(nn.Module, SupportsPP): - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = DeepseekModel(config, cache_config, quant_config) + self.model = DeepseekModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 0f391d8329a8e..4fb1eed15a2e7 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -51,7 +51,8 @@ from .interfaces import SupportsPP from .utils import (PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) class DeepseekV2MLP(nn.Module): @@ -408,14 +409,13 @@ class DeepseekV2Model(nn.Module): fall_back_to_pt_during_load = False - def __init__( - self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -479,21 +479,14 @@ def forward( class DeepseekV2ForCausalLM(nn.Module, SupportsPP): - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config self.config = config 
self.quant_config = quant_config - self.model = DeepseekV2Model(config, - cache_config, - quant_config, - prefix="model") + self.model = DeepseekV2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py index 6bd73d20d340d..c902829994c7c 100644 --- a/vllm/model_executor/models/eagle.py +++ b/vllm/model_executor/models/eagle.py @@ -14,6 +14,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .utils import maybe_prefix + class EAGLE(nn.Module): """This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077 @@ -42,7 +44,8 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: architectures = getattr(self.config.model, "architectures", []) model_cls, _ = ModelRegistry.resolve_model_cls(architectures) - self.model = model_cls(vllm_config, prefix) + self.model = model_cls(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) self.fc = nn.Linear(config.model.hidden_size * 2, config.model.hidden_size, bias=getattr(self.config, "eagle_fc_bias", False)) diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index fa6dbfe35b3ad..cd3e7da657e0e 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -29,7 +29,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, LoRAConfig, VllmConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul @@ -54,7 +54,8 @@ from .interfaces import SupportsLoRA, SupportsPP from .utils import (PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) class ExaoneGatedMLP(nn.Module): @@ -314,15 +315,14 @@ def forward( @support_torch_compile class ExaoneModel(nn.Module): - def __init__( - self, - config: ExaoneConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config self.padding_idx = config.pad_token_id lora_vocab = ((lora_config.lora_extra_vocab_size * @@ -438,14 +438,9 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): "c_fc_1": ("gate_up_proj", 1), } - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config @@ -453,11 +448,8 @@ def __init__( self.lora_config = lora_config self.transformer = ExaoneModel( - config, - cache_config, - quant_config, - lora_config=lora_config, - prefix="model", + 
vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), ) if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 96ae119042277..562ee5517e7f1 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -48,7 +48,8 @@ from .interfaces import SupportsPP from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) FalconConfig = Union[HF_FalconConfig, RWConfig] @@ -332,14 +333,13 @@ def forward( @support_torch_compile class FalconModel(nn.Module): - def __init__( - self, - config: FalconConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads @@ -408,11 +408,12 @@ def __init__( ): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.transformer = FalconModel(config, cache_config, quant_config) + self.transformer = FalconModel(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "transformer")) # only Falcon-11B doesn't share lm_head weight with word embeddings # and previous Falcon model doesn't have tie_word_embeddings config # so we set tie_word_embeddings to True by default diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py index b0d970d9fb572..971a71180164b 100644 --- a/vllm/model_executor/models/florence2.py +++ b/vllm/model_executor/models/florence2.py @@ -3,13 +3,10 @@ import torch import torch.nn as nn -from transformers import PretrainedConfig from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig, VllmConfig +from vllm.config import VllmConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.bart import (BartDecoder, BartEncoder, @@ -23,11 +20,13 @@ class Florence2LanguageModel(nn.Module): - def __init__(self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config self.padding_idx = config.pad_token_id @@ -93,15 +92,14 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, class Florence2LanguageForConditionalGeneration(nn.Module): - def __init__(self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = 
vllm_config.model_config.hf_config + self.config = config - self.model = Florence2LanguageModel(config, - cache_config=cache_config, - quant_config=quant_config) + self.model = Florence2LanguageModel(vllm_config=vllm_config, + prefix=prefix) embed_scale = math.sqrt( config.d_model) if config.scale_embedding else 1.0 @@ -189,17 +187,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): class Florence2ForConditionalGeneration(nn.Module): - def __init__(self, vllm_config: VllmConfig, prefix: str = ""): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config # TODO(Isotr0py): Add vision backbone self.language_model = Florence2LanguageForConditionalGeneration( - config=config.text_config, - cache_config=cache_config, - quant_config=quant_config) + vllm_config=vllm_config.with_hf_config(config.text_config), + prefix=prefix, + ) @property def sampler(self): diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 4e0cbfb9cbf58..55baba809e58f 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -258,14 +258,13 @@ def forward( @support_torch_compile class GemmaModel(nn.Module): - def __init__( - self, - config: GemmaConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config self.embed_tokens = VocabParallelEmbedding( @@ -372,14 +371,9 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): embedding_modules = {} embedding_padding_modules = [] - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config @@ -389,9 +383,7 @@ def __init__( self.lora_config = lora_config self.quant_config = quant_config - self.model = GemmaModel(config, - cache_config, - quant_config, + self.model = GemmaModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = get_sampler() diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 773d3b72ec418..eeb3fd98a7eac 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -43,7 +43,8 @@ from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) logger = init_logger(__name__) @@ -243,11 +244,7 @@ def forward( @support_torch_compile class Gemma2Model(nn.Module): - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config @@ -399,13 +396,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): 
"up_proj": ("gate_up_proj", 1), } - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config del lora_config # Unused. @@ -414,7 +406,8 @@ def __init__( # currently all existing Gemma models have `tie_word_embeddings` enabled assert config.tie_word_embeddings self.quant_config = quant_config - self.model = Gemma2Model(config, cache_config, quant_config) + self.model = Gemma2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) self.logits_processor = LogitsProcessor( config.vocab_size, soft_cap=config.final_logit_softcapping) self.sampler = get_sampler() @@ -471,14 +464,11 @@ class Gemma2EmbeddingModel(nn.Module, SupportsPP): _pooler: An instance of Pooler used for pooling operations. """ - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() - self.model = Gemma2Model(vllm_config, prefix) + self.model = Gemma2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) self._pooler = Pooler.from_config_with_defaults( vllm_config.model_config.pooler_config, pooling_type=PoolingType.LAST, diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index c3fc47db79986..fcff7ec2e01eb 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -42,7 +42,8 @@ from .interfaces import SupportsPP from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) class GPT2Attention(nn.Module): @@ -184,14 +185,13 @@ def forward( @support_torch_compile class GPT2Model(nn.Module): - def __init__( - self, - config: GPT2Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config assert not config.add_cross_attention assert not config.scale_attn_by_inverse_layer_idx @@ -247,14 +247,12 @@ def __init__( ): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.transformer = GPT2Model(config, - cache_config, - quant_config, - prefix="transformer") + self.transformer = GPT2Model(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "transformer")) if self.config.tie_word_embeddings: self.lm_head = self.transformer.wte else: diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index ea1614d966365..ae1495ebd7914 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -25,7 +25,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, LoRAConfig, VllmConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from 
vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -189,15 +189,14 @@ def forward( @support_torch_compile class GPTBigCodeModel(nn.Module): - def __init__( - self, - config: GPTBigCodeConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - prefix: str = "", - ): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config assert not config.add_cross_attention @@ -265,7 +264,6 @@ def __init__( ): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config @@ -273,8 +271,8 @@ def __init__( self.lora_config = lora_config self.quant_config = quant_config - self.transformer = GPTBigCodeModel(config, cache_config, quant_config, - lora_config) + self.transformer = GPTBigCodeModel(vllm_config=vllm_config, + prefix=prefix) if self.config.tie_word_embeddings: self.lm_head = self.transformer.wte else: diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 58cff67c69051..610795b084b44 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -42,7 +42,8 @@ from .interfaces import SupportsPP from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) class GPTJAttention(nn.Module): @@ -177,14 +178,13 @@ def forward( @support_torch_compile class GPTJModel(nn.Module): - def __init__( - self, - config: GPTJConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config self.embed_dim = config.n_embd self.wte = VocabParallelEmbedding( @@ -236,12 +236,13 @@ def __init__( ): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config assert not config.tie_word_embeddings - self.transformer = GPTJModel(config, cache_config, quant_config) + self.transformer = GPTJModel(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "transformer")) self.lm_head = ParallelLMHead( config.vocab_size, config.n_embd, diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 27b2577a8cdca..f5603772e9862 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -41,7 +41,8 @@ from .interfaces import SupportsPP from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) class GPTNeoXAttention(nn.Module): @@ -189,14 +190,13 @@ def forward( @support_torch_compile class GPTNeoXModel(nn.Module): - def __init__( - self, - config: GPTNeoXConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: 
Optional[QuantizationConfig] = None, - prefix: str = "", - ): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config self.embed_in = VocabParallelEmbedding( @@ -249,11 +249,11 @@ def __init__( ): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.gpt_neox = GPTNeoXModel(config, cache_config, quant_config) + self.gpt_neox = GPTNeoXModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "gpt_neox")) self.embed_out = ParallelLMHead( config.vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index c3e23b7138e7f..d1e6e31f2b8d1 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -28,7 +28,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, LoRAConfig, VllmConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul @@ -52,7 +52,8 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers +from .utils import (PPMissingLayer, is_pp_missing_parameter, make_layers, + maybe_prefix) class GraniteMLP(nn.Module): @@ -257,15 +258,14 @@ def forward( @support_torch_compile class GraniteModel(nn.Module): - def __init__( - self, - config: GraniteConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config self.padding_idx = config.pad_token_id lora_vocab = (lora_config.lora_extra_vocab_size * @@ -370,25 +370,17 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): "up_proj": ("gate_up_proj", 1), } - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config self.config = config self.lora_config = lora_config - self.model = GraniteModel(config, - cache_config, - quant_config, - lora_config=lora_config, - prefix="model") + self.model = GraniteModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size if lora_config: diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 73f7c106e3d39..2ed115c56af45 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -28,7 +28,7 @@ from vllm.attention import Attention, 
AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, LoRAConfig, VllmConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm @@ -47,7 +47,7 @@ from . import mixtral from .interfaces import SupportsLoRA, SupportsPP -from .utils import make_layers +from .utils import make_layers, maybe_prefix class GraniteMoeMoE(nn.Module): @@ -247,15 +247,14 @@ def forward( @support_torch_compile class GraniteMoeModel(nn.Module): - def __init__( - self, - config: GraniteMoeConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.padding_idx = config.pad_token_id lora_vocab = (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) if lora_config else 0 @@ -333,25 +332,17 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): } embedding_padding_modules = ["lm_head"] - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config self.config = config self.lora_config = lora_config - self.model = GraniteMoeModel(config, - cache_config, - quant_config, - lora_config=lora_config, - prefix="model") + self.model = GraniteMoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index b676171b556a7..b234b602e6fbf 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -22,17 +22,15 @@ from PIL import Image from torch import nn # Temporary solution for transformers below 4.46.0. 
-from transformers import PretrainedConfig as Idefics3Config from transformers import ProcessorMixin as Idefics3ImageProcessor from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig, VllmConfig +from vllm.config import VllmConfig from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -48,7 +46,8 @@ # yapf: enable from .interfaces import SupportsMultiModal from .llama import LlamaModel -from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings +from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, + merge_multimodal_embeddings) logger = init_logger(__name__) @@ -417,13 +416,13 @@ def forward(self, image_hidden_states: torch.Tensor) -> torch.Tensor: class Idefics3Model(nn.Module): - def __init__( - self, - config: Idefics3Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - ): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config self.padding_idx = self.config.text_config.pad_token_id self.vocab_size = self.config.text_config.vocab_size @@ -613,22 +612,18 @@ def forward( @INPUT_REGISTRY.register_input_processor(input_processor_for_idefics3) class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal): - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config - self.model = Idefics3Model(config, cache_config, quant_config) + self.model = Idefics3Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) self.image_token_id = self.config.image_token_id self.lm_head = ParallelLMHead( diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index cbedd0c8a0130..21fa6983063b8 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -250,14 +250,13 @@ def forward( @support_torch_compile class InternLM2Model(nn.Module): - def __init__( - self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -317,20 +316,13 @@ def forward( class InternLM2ForCausalLM(nn.Module, SupportsPP): - def __init__( - self, - vllm_config: VllmConfig, 
- prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = InternLM2Model(config, - cache_config, - quant_config, + self.model = InternLM2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) self.output = ParallelLMHead(config.vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/internlm2_ve.py b/vllm/model_executor/models/internlm2_ve.py index 51e2c64d5552d..34889d691a934 100644 --- a/vllm/model_executor/models/internlm2_ve.py +++ b/vllm/model_executor/models/internlm2_ve.py @@ -104,14 +104,13 @@ def forward( class InternLM2VEModel(InternLM2Model): - def __init__( - self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__(config, cache_config, quant_config) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: InternLM2VEDecoderLayer( @@ -159,12 +158,8 @@ def forward( class InternLM2VEForCausalLM(InternLM2ForCausalLM): - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: - super().__init__(vllm_config, prefix=prefix) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 42bccf71273b3..77efc9a26ef7a 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -35,7 +35,7 @@ get_clip_num_patches) from .interfaces import SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - merge_multimodal_embeddings) + maybe_prefix, merge_multimodal_embeddings) IMG_START = '' IMG_END = '' @@ -435,13 +435,13 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: config, quant_config=quant_config, is_mono=self.is_mono, - prefix="vision_model", + prefix=maybe_prefix(prefix, "vision_model"), ) self.language_model = init_vllm_registered_model( config.text_config, vllm_config=vllm_config, - prefix="language_model") + prefix=maybe_prefix(prefix, "language_model")) self.mlp1 = self._init_mlp1(config) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index ae3f5b01d5cce..4dc9271703a8d 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -44,7 +44,8 @@ from .interfaces import SupportsPP from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) class SwiGLUActivation(nn.Module): @@ -215,14 +216,13 @@ def forward( @support_torch_compile class JAISModel(nn.Module): - def __init__( - self, - config: JAISConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = 
"", - ): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config assert not config.add_cross_attention assert not config.scale_attn_by_inverse_layer_idx @@ -293,11 +293,12 @@ def __init__( ): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.transformer = JAISModel(config, cache_config, quant_config) + self.transformer = JAISModel(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "transformer")) if self.config.tie_word_embeddings: self.lm_head = self.transformer.wte else: diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 72eb1017c2868..88fb8d5cf555a 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -7,7 +7,7 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention -from vllm.config import CacheConfig, LoRAConfig, VllmConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm @@ -29,6 +29,7 @@ _get_graph_batch_size) from .interfaces import HasInnerState, SupportsLoRA +from .utils import maybe_prefix KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -258,14 +259,14 @@ def forward( class JambaModel(nn.Module): - def __init__( - self, - config: JambaConfig, - quant_config: Optional[QuantizationConfig] = None, - cache_config: Optional[CacheConfig] = None, - lora_config: Optional[LoRAConfig] = None, - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config self.padding_idx = config.pad_token_id lora_vocab = ((lora_config.lora_extra_vocab_size * @@ -348,14 +349,9 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA): } embedding_padding_modules = ["lm_head"] - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config assert not cache_config.enable_prefix_caching, \ @@ -364,10 +360,8 @@ def __init__( super().__init__() self.config = config self.scheduler_config = scheduler_config - self.model = JambaModel(config, - cache_config=cache_config, - quant_config=quant_config, - lora_config=lora_config) + self.model = JambaModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index b765912387e2e..2472128976d88 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -28,7 +28,7 @@ from vllm.attention import Attention, AttentionMetadata from 
vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, LoRAConfig, VllmConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul @@ -271,15 +271,14 @@ def forward( @support_torch_compile class LlamaModel(nn.Module): - def __init__( - self, - config: LlamaConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config self.padding_idx = config.pad_token_id lora_vocab = (lora_config.lora_extra_vocab_size * @@ -492,24 +491,16 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): "norm": "model.norm" } - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config pooler_config = vllm_config.model_config.pooler_config self.config = config self.lora_config = lora_config - self.model = LlamaModel(config, - cache_config, - quant_config, - lora_config=lora_config, + self.model = LlamaModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size @@ -652,23 +643,12 @@ class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP): } embedding_padding_modules = [] - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config pooler_config = vllm_config.model_config.pooler_config - self.model = LlamaModel(config, - cache_config, - quant_config, - lora_config, + self.model = LlamaModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) self._pooler = Pooler.from_config_with_defaults( pooler_config, diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index c98462537728a..ca963fa1c52ea 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -32,7 +32,7 @@ dummy_seq_data_for_siglip, get_max_siglip_image_tokens, input_processor_for_siglip) from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - merge_multimodal_embeddings) + maybe_prefix, merge_multimodal_embeddings) class LlavaImagePixelInputs(TypedDict): @@ -282,7 +282,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: config, quant_config, require_post_norm=False, - prefix="vision_tower") + prefix=maybe_prefix(prefix, "vision_tower")) self.multi_modal_projector = LlavaMultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, text_hidden_size=config.text_config.hidden_size, @@ -291,7 +291,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: 
self.language_model = init_vllm_registered_model( config.text_config, vllm_config=vllm_config, - prefix="language_model") + prefix=maybe_prefix(prefix, "language_model")) self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index f187f8105b96a..0b621a23ec980 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -31,7 +31,7 @@ dummy_seq_data_for_siglip, get_siglip_image_feature_size, get_siglip_patch_grid_length, input_processor_for_siglip) from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn, - init_vllm_registered_model) + init_vllm_registered_model, maybe_prefix) class LlavaNextImagePixelInputs(TypedDict): @@ -296,7 +296,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: config, quant_config, require_post_norm=False, - prefix="vision_tower") + prefix=maybe_prefix(prefix, "vision_tower")) self.image_newline = nn.Parameter( torch.empty(config.text_config.hidden_size)) self.multi_modal_projector = LlavaMultiModalProjector( @@ -307,7 +307,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: self.language_model = init_vllm_registered_model( config.text_config, vllm_config=vllm_config, - prefix="language_model") + prefix=maybe_prefix(prefix, "language_model")) # The same model class supports both language generation and embedding # because the architecture name is the same diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index eceb0c0ab52df..b030c2f5fdc47 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -29,7 +29,7 @@ from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip) from .utils import (AutoWeightsLoader, init_vllm_registered_model, - merge_multimodal_embeddings) + maybe_prefix, merge_multimodal_embeddings) # For profile run _MAX_FRAMES_PER_VIDEO = 32 @@ -267,7 +267,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: config, quant_config, require_post_norm=False, - prefix="vision_tower") + prefix=maybe_prefix(prefix, "vision_tower")) self.vision_resampler = LlavaNextVideoPooler(config) self.multi_modal_projector = LlavaNextMultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, @@ -276,7 +276,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: self.language_model = init_vllm_registered_model( config.text_config, vllm_config=vllm_config, - prefix="language_model") + prefix=maybe_prefix(prefix, "language_model")) self.make_empty_intermediate_tensors = ( self.language_model.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 64d373ce91509..c129f140d8d12 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -35,7 +35,7 @@ dummy_video_for_siglip, get_siglip_image_feature_size, get_siglip_patch_grid_length, input_processor_for_siglip) from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - merge_multimodal_embeddings) + maybe_prefix, merge_multimodal_embeddings) # Result in the max possible feature size (2x2 grid of 336x336px tiles) MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448 @@ -418,12 +418,12 @@ def __init__(self, 
vllm_config: VllmConfig, prefix: str = "") -> None: config, quant_config, require_post_norm=False, - prefix="vision_tower") + prefix=maybe_prefix(prefix, "vision_tower")) self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config) self.language_model = init_vllm_registered_model( config.text_config, vllm_config=vllm_config, - prefix="language_model") + prefix=maybe_prefix(prefix, "language_model")) self.image_newline = nn.Parameter( torch.empty(config.text_config.hidden_size)) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 49e43f8cc683c..55c575e22a0f6 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -6,7 +6,7 @@ from transformers import MambaConfig from vllm.attention.backends.abstract import AttentionMetadata -from vllm.config import CacheConfig, LoRAConfig, VllmConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -26,6 +26,8 @@ from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, _get_graph_batch_size) +from .utils import maybe_prefix + KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -73,14 +75,14 @@ def forward( class MambaModel(nn.Module): - def __init__( - self, - config: MambaConfig, - quant_config: Optional[QuantizationConfig] = None, - cache_config: Optional[CacheConfig] = None, - lora_config: Optional[LoRAConfig] = None, - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config self.padding_idx = config.pad_token_id lora_vocab = ((lora_config.lora_extra_vocab_size * @@ -130,14 +132,9 @@ def forward( class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree): - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config assert not cache_config.enable_prefix_caching, \ @@ -146,10 +143,8 @@ def __init__( super().__init__() self.config = config self.scheduler_config = scheduler_config - self.backbone = MambaModel(config, - cache_config=cache_config, - quant_config=quant_config, - lora_config=lora_config) + self.backbone = MambaModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "backbone")) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 559d9c4dd35bf..2db953329fd91 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -29,7 +29,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, LoRAConfig, VllmConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -53,7 +53,8 @@ 
from .interfaces import SupportsLoRA, SupportsPP from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) class MiniCPMMoE(nn.Module): @@ -351,15 +352,14 @@ def forward( @support_torch_compile class MiniCPMModel(nn.Module): - def __init__( - self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config self.cache_config = cache_config self.quant_config = quant_config @@ -461,24 +461,22 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): } embedding_padding_modules = ["lm_head"] - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config + self.prefix = prefix + self.vllm_config = vllm_config self.config = config self.lora_config = lora_config self.cache_config = cache_config self.quant_config = quant_config self.num_experts = getattr(self.config, "num_experts", 0) - self._init_model() + self._init_model(vllm_config=vllm_config, prefix=prefix) unpadded_vocab_size = config.vocab_size if lora_config: unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -502,11 +500,9 @@ def __init__( self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) - def _init_model(self): - self.model = MiniCPMModel(config=self.config, - cache_config=self.cache_config, - quant_config=self.quant_config, - lora_config=self.lora_config) + def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""): + self.model = MiniCPMModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) def forward( self, diff --git a/vllm/model_executor/models/minicpm3.py b/vllm/model_executor/models/minicpm3.py index eeedf55cf3e57..278c4bbe6e563 100644 --- a/vllm/model_executor/models/minicpm3.py +++ b/vllm/model_executor/models/minicpm3.py @@ -28,7 +28,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -40,7 +40,7 @@ MiniCPMForCausalLM, MiniCPMModel) -from .utils import make_layers +from .utils import make_layers, maybe_prefix class MiniCPM3Attention(nn.Module): @@ -238,8 +238,6 @@ class MiniCPM3ForCausalLM(MiniCPMForCausalLM): # `embedding_modules` and `embedding_padding_modules` # are inherited from MiniCPMForCausalLM - def _init_model(self): - self.model = MiniCPM3Model(config=self.config, - cache_config=self.cache_config, - quant_config=self.quant_config, - lora_config=self.lora_config) + def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""): + self.model = MiniCPM3Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, 
"model")) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 9458204c5a038..aae534c0b5949 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -34,7 +34,7 @@ from typing_extensions import NotRequired from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig, VllmConfig +from vllm.config import VllmConfig from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -59,7 +59,7 @@ from .idefics2_vision_model import Idefics2VisionTransformer from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP -from .utils import is_pp_missing_parameter +from .utils import is_pp_missing_parameter, maybe_prefix _KEYS_TO_MODIFY_MAPPING = { "llm.lm_head": "lm_head", @@ -390,7 +390,6 @@ def __init__( ): config = vllm_config.model_config.hf_config multimodal_config = vllm_config.model_config.multimodal_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config super().__init__() # All MiniCPM-V models disable `tie_word_embeddings` but @@ -401,11 +400,11 @@ def __init__( self.multimodal_config = multimodal_config self.version = get_version_by_config(self.config) - self.llm = self.init_llm(config, - cache_config, - quant_config, - prefix="llm") - self.vpm = self.init_vision_module(config, quant_config, prefix="vpm") + self.llm = self.init_llm(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "llm")) + self.vpm = self.init_vision_module(config, + quant_config, + prefix=maybe_prefix(prefix, "vpm")) param_dtype = torch.get_default_dtype() self.vpm.to(dtype=param_dtype) self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else @@ -414,13 +413,15 @@ def __init__( self.resampler = self.init_resampler(self.embed_dim, self.vision_dim, quant_config=quant_config, - prefix="resampler") + prefix=maybe_prefix( + prefix, "resampler")) self.resampler.to(device="cuda", dtype=param_dtype) # TODO: why is there _KEYS_TO_MODIFY_MAPPING? 
lm_head should be in llm self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config, - prefix="llm.lm_head") + prefix=maybe_prefix( + prefix, "llm.lm_head")) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = get_sampler() @@ -661,9 +662,7 @@ def get_mm_mapping(self) -> MultiModelKeys: def init_llm( self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, prefix: str = "", ) -> nn.Module: raise NotImplementedError @@ -711,16 +710,10 @@ def __init__( def init_llm( self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, prefix: str = "", ) -> nn.Module: - - return LLMWrapper(MiniCPMModel(config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), + return LLMWrapper(MiniCPMModel(vllm_config=vllm_config, prefix=prefix), name="model") def init_vision_module( @@ -875,15 +868,10 @@ def __init__( def init_llm( self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, prefix: str = "", ) -> nn.Module: - return LLMWrapper(LlamaModel(config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), + return LLMWrapper(LlamaModel(vllm_config=vllm_config, prefix=prefix), name="model") def init_vision_module( @@ -1022,16 +1010,10 @@ def __init__( def init_llm( self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + vllm_config: VllmConfig, prefix: str = "", ) -> nn.Module: - - return LLMWrapper(Qwen2Model(config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), + return LLMWrapper(Qwen2Model(vllm_config=vllm_config, prefix=prefix), name="model") def init_vision_module( @@ -1151,4 +1133,4 @@ def __new__(cls, vllm_config: VllmConfig, prefix: str = ""): if instance_class is None: raise ValueError( "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6") - return instance_class(vllm_config, prefix=prefix) + return instance_class(vllm_config=vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 91ec3228c0d48..3eb2f60fd4fc7 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -28,7 +28,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, LoRAConfig, VllmConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm @@ -48,7 +48,8 @@ from .interfaces import SupportsLoRA, SupportsPP from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) class MixtralMoE(nn.Module): @@ -248,15 +249,14 @@ def forward( @support_torch_compile class MixtralModel(nn.Module): - def __init__( - self, - config: MixtralConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - prefix: str = "", - ) -> None: + 
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.padding_idx = config.pad_token_id lora_vocab = (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) if lora_config else 0 @@ -332,24 +332,16 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): } embedding_padding_modules = ["lm_head"] - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config self.config = config self.lora_config = lora_config - self.model = MixtralModel(config, - cache_config, - quant_config, - lora_config=lora_config, - prefix="model") + self.model = MixtralModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index aeac326776392..95cfb6f54dc10 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -49,7 +49,8 @@ from .interfaces import SupportsPP from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) class MixtralMLP(nn.Module): @@ -293,14 +294,13 @@ def forward( class MixtralModel(nn.Module): - def __init__( - self, - config: MixtralConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -350,18 +350,14 @@ def forward( class MixtralForCausalLM(nn.Module, SupportsPP): fall_back_to_pt_during_load = False - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = MixtralModel(config, cache_config, quant_config) + self.model = MixtralModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 14aa515570f38..e5c1d28e6e7ea 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -33,7 +33,7 @@ import vllm.distributed.parallel_state as ps from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention.ops.paged_attn import PagedAttention -from vllm.config import CacheConfig, VllmConfig +from vllm.config import VllmConfig from vllm.distributed import 
get_tensor_model_parallel_world_size from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs, InputContext, TokenInputs, token_inputs) @@ -56,6 +56,7 @@ from .clip import CLIPMLP from .interfaces import SupportsMultiModal from .llama import LlamaDecoderLayer, LlamaMLP +from .utils import maybe_prefix logger = init_logger(__name__) MLLAMA_IMAGE_TOKEN_ID = 128256 @@ -939,15 +940,13 @@ class MllamaTextModel(nn.Module): config_class = config_mllama.MllamaTextConfig base_model_prefix = "model" - def __init__( - self, - config: config_mllama.MllamaTextConfig, - cache_config: Optional[CacheConfig], - quant_config: Optional[QuantizationConfig], - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + config = vllm_config.model_config.hf_config.text_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8, @@ -1029,18 +1028,14 @@ class MllamaForCausalLM(nn.Module): "MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer" ] - def __init__( - self, - config: config_mllama.MllamaTextConfig, - cache_config: Optional[CacheConfig], - quant_config: Optional[QuantizationConfig], - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config.text_config + quant_config = vllm_config.quant_config + self.vocab_size = config.vocab_size - self.model = MllamaTextModel(config, - cache_config, - quant_config, + self.model = MllamaTextModel(vllm_config=vllm_config, prefix=f"{prefix}.model") self.lm_head = ParallelLMHead( config.vocab_size, @@ -1108,14 +1103,9 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): "up_proj": ("gate_up_proj", 1), } - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config self.vocab_size = config.text_config.vocab_size self.hidden_size = config.text_config.hidden_size @@ -1127,12 +1117,11 @@ def __init__( self.vision_model = MllamaVisionModel(config.vision_config, quant_config, - prefix="vision_model") + prefix=maybe_prefix( + prefix, "vision_model")) self.language_model = MllamaForCausalLM( - config.text_config, - cache_config=cache_config, - quant_config=quant_config, - prefix="language_model", + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "language_model"), ) self.multi_modal_projector = ColumnParallelLinear( config.vision_config.vision_output_dim, @@ -1140,7 +1129,7 @@ def __init__( bias=True, quant_config=quant_config, gather_output=True, - prefix="multi_modal_projector", + prefix=maybe_prefix(prefix, "multi_modal_projector"), ) self.logits_processor = LogitsProcessor(config.output_hidden_states, config.text_config.vocab_size) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index cd462c4d0495e..035a1e2ab7b02 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -44,7 +44,8 @@ from .interfaces import SupportsMultiModal, SupportsPP from .utils import (get_vit_attn_backend, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, 
make_layers, + maybe_prefix) # TODO: hard-coded for now. Consider making it configurable. VIT_LAYERS = [-2, -9] @@ -716,14 +717,13 @@ def forward( @support_torch_compile class MolmoModel(nn.Module): - def __init__( - self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config self.embedding_size = config.embedding_size or config.vocab_size @@ -1024,14 +1024,9 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs): @INPUT_REGISTRY.register_input_processor(input_processor_for_molmo) class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config self.config = config @@ -1040,7 +1035,8 @@ def __init__( vision_config = VisionBackboneConfig() self.vision_backbone = MolmoVisionBackbone(config, vision_config, quant_config) - self.model = MolmoModel(config, cache_config, quant_config) + self.model = MolmoModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) if self.config.weight_tying: self.lm_head = self.model.transformer.wte diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 672c8e9c22260..e15c0fe8db060 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -26,7 +26,8 @@ from .interfaces import SupportsPP from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) def _get_alibi_slopes( @@ -207,14 +208,13 @@ def forward( @support_torch_compile class MPTModel(nn.Module): - def __init__( - self, - config: MPTConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + assert config.embedding_fraction == 1.0 assert config.norm_type == "low_precision_layernorm" @@ -267,20 +267,16 @@ def forward( class MPTForCausalLM(nn.Module, SupportsPP): - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config self.config = config assert config.tie_word_embeddings self.quant_config = quant_config - self.transformer = MPTModel(config, cache_config, quant_config) + self.transformer = MPTModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "transformer")) self.lm_head = self.transformer.wte self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = get_sampler() diff --git a/vllm/model_executor/models/nemotron.py 
b/vllm/model_executor/models/nemotron.py index 5991cce642981..e09d7088a69ce 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -27,7 +27,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, LoRAConfig, VllmConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -47,7 +47,8 @@ from .interfaces import SupportsLoRA, SupportsPP from .utils import (PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) # The architecture is pretty similar to Llama, with these changes: # - There is no gate_proj, just up_proj @@ -293,15 +294,14 @@ def forward( @support_torch_compile class NemotronModel(nn.Module): - def __init__( - self, - config: NemotronConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config self.padding_idx = config.pad_token_id lora_vocab = (lora_config.lora_extra_vocab_size * @@ -401,14 +401,9 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): "v_proj": ("qkv_proj", 2), } - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config assert isinstance(config, NemotronConfig) @@ -416,11 +411,8 @@ def __init__( self.config = config self.lora_config = lora_config - self.model = NemotronModel(config, - cache_config, - quant_config, - lora_config=lora_config, - prefix="model") + self.model = NemotronModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size if lora_config: diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 6905f8521a8c3..3467ae5896494 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -46,7 +46,8 @@ from .interfaces import SupportsPP from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) class OlmoAttention(nn.Module): @@ -224,12 +225,13 @@ def forward( @support_torch_compile class OlmoModel(nn.Module): - def __init__(self, - config: OlmoConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config self.embed_tokens = 
VocabParallelEmbedding(config.vocab_size, @@ -291,17 +293,13 @@ class OlmoForCausalLM(nn.Module, SupportsPP): Extremely barebones HF model wrapper. """ - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config self.config = config - self.model = OlmoModel(config, cache_config, quant_config) + self.model = OlmoModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index 8fa90d17003af..3d31919edd862 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -38,7 +38,8 @@ from .interfaces import SupportsPP from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) class OlmoeMoE(nn.Module): @@ -243,14 +244,13 @@ def forward( @support_torch_compile class OlmoeModel(nn.Module): - def __init__( - self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -309,18 +309,14 @@ class OlmoeForCausalLM(nn.Module, SupportsPP): fall_back_to_pt_during_load = False - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = OlmoeModel(config, cache_config, quant_config) + self.model = OlmoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index d378956b68cfc..58b6107eba347 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -293,14 +293,13 @@ def forward( @support_torch_compile class OPTModel(nn.Module): - def __init__( - self, - config: OPTConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.decoder = OPTDecoder(config, cache_config, quant_config, @@ -342,21 +341,14 @@ class OPTForCausalLM(nn.Module, SupportsPP): ".q_proj.", ".k_proj.", ".v_proj.", ".out_proj.", ".fc1.", ".fc2." 
] - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config super().__init__() self.config = config self.quant_config = quant_config - self.model = OPTModel(config, - cache_config, - quant_config, + self.model = OPTModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) if self.config.tie_word_embeddings: self.lm_head = self.model.decoder.embed_tokens diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index b400d4e3f5228..38821c8288347 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -29,7 +29,8 @@ from .interfaces import SupportsPP from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) class OrionMLP(nn.Module): @@ -208,14 +209,13 @@ def forward( @support_torch_compile class OrionModel(nn.Module): - def __init__( - self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -268,18 +268,14 @@ def forward( class OrionForCausalLM(nn.Module, SupportsPP): - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = OrionModel(config, cache_config, quant_config) + self.model = OrionModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 69b7fe9d56847..eea229359255e 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -20,7 +20,7 @@ from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip, get_max_siglip_image_tokens) from .utils import (AutoWeightsLoader, init_vllm_registered_model, - merge_multimodal_embeddings) + maybe_prefix, merge_multimodal_embeddings) logger = init_logger(__name__) @@ -131,11 +131,7 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config @@ -145,7 +141,8 @@ def __init__( self.vision_tower = SiglipVisionModel(config.vision_config, quant_config, - prefix="vision_tower") + prefix=maybe_prefix( + prefix, "vision_tower")) self.multi_modal_projector = PaliGemmaMultiModalProjector( 
vision_hidden_size=config.vision_config.hidden_size, projection_dim=config.vision_config.projection_dim) @@ -155,7 +152,7 @@ def __init__( self.language_model = init_vllm_registered_model( config.text_config, vllm_config=vllm_config, - prefix="language_model") + prefix=maybe_prefix(prefix, "language_model")) logit_scale = getattr(config, "logit_scale", 1.0) self.language_model.logits_processor.scale *= logit_scale diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index a86e2c1b4e4a1..2e34a7cc30873 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -45,7 +45,8 @@ from .interfaces import SupportsPP from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) class PersimmonMLP(nn.Module): @@ -212,12 +213,13 @@ def forward( @support_torch_compile class PersimmonModel(nn.Module): - def __init__(self, - config: PersimmonConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding(config.vocab_size, @@ -265,20 +267,13 @@ def forward( class PersimmonForCausalLM(nn.Module, SupportsPP): - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config self.config = config self.vocab_size = config.vocab_size - self.model = PersimmonModel(config, - cache_config=cache_config, - quant_config=quant_config) + self.model = PersimmonModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, bias=False) diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index fef921528b042..262f6996fc374 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -60,7 +60,8 @@ from .interfaces import SupportsLoRA, SupportsPP from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) class PhiAttention(nn.Module): @@ -196,12 +197,13 @@ def forward( @support_torch_compile class PhiModel(nn.Module): - def __init__(self, - config: PhiConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config self.quant_config = quant_config self.embed_tokens = VocabParallelEmbedding(config.vocab_size, @@ -277,14 +279,9 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP): embedding_modules = {} embedding_padding_modules = [] - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): 
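Note: every file touched above and below applies the same constructor convention, so a minimal standalone sketch of it may help when reading the hunks. FooModel and FooForCausalLM are hypothetical names, not classes in vLLM; only the vllm imports and attribute accesses mirror real usage from this patch.

    import torch.nn as nn

    from vllm.config import VllmConfig
    from vllm.model_executor.models.utils import maybe_prefix


    class FooModel(nn.Module):
        """Inner model: unpacks what it needs from the single VllmConfig."""

        def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
            super().__init__()
            config = vllm_config.model_config.hf_config
            cache_config = vllm_config.cache_config
            quant_config = vllm_config.quant_config
            # ... build embeddings/layers from config, cache_config and
            # quant_config exactly as before; only the plumbing changed ...


    class FooForCausalLM(nn.Module):
        """Outer model: forwards the whole config plus a dotted prefix."""

        def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
            super().__init__()
            self.model = FooModel(vllm_config=vllm_config,
                                  prefix=maybe_prefix(prefix, "model"))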
super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config self.config = config @@ -294,7 +291,8 @@ def __init__( self.quant_config = quant_config - self.model = PhiModel(config, cache_config, quant_config) + self.model = PhiModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index de1b09eba6c6d..8a5fb6d303e60 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -24,7 +24,8 @@ from .interfaces import SupportsPP from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) def load_column_parallel_weight(param: torch.nn.Parameter, @@ -299,14 +300,13 @@ def forward( class Phi3SmallModel(nn.Module): - def __init__( - self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) @@ -363,18 +363,14 @@ def forward( class Phi3SmallForCausalLM(nn.Module, SupportsPP): _tied_weights_keys = ["lm_head.weight"] - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = Phi3SmallModel(config, cache_config, quant_config) + self.model = Phi3SmallModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) self.vocab_size = config.vocab_size self.mup_width_multiplier = config.mup_width_multiplier self.lm_head = ParallelLMHead( diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 65131d61673a3..4b5dc944bce4b 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -45,7 +45,7 @@ from .clip import dummy_image_for_clip, dummy_seq_data_for_clip from .interfaces import SupportsMultiModal, SupportsPP -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, +from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix, merge_multimodal_embeddings) logger = init_logger(__name__) @@ -525,11 +525,7 @@ def input_processor_for_phi3v(ctx: InputContext, @INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v) class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config @@ -544,12 +540,14 @@ def __init__( config.hidden_size, org_num_embeddings=config.vocab_size, quant_config=quant_config, - prefix="model.embed_tokens", + prefix=maybe_prefix(prefix, 
"model.embed_tokens"), ) # TODO: Optionally initializes this for supporting input embeddings. self.vision_embed_tokens = Phi3HDImageEmbedding( - config, quant_config, prefix="model.vision_embed_tokens") + config, + quant_config, + prefix=maybe_prefix(prefix, "model.vision_embed_tokens")) # The prefix is empty intentionally because default prefix of # LlamaForCausalLM is "model" diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 17d00c0ede2b2..6d71a8949111b 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -28,7 +28,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, LoRAConfig, VllmConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import (QKVParallelLinear, @@ -48,7 +48,8 @@ from .interfaces import SupportsLoRA, SupportsPP from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) class PhiMoEConfig(PretrainedConfig): @@ -432,15 +433,14 @@ def forward( @support_torch_compile class PhiMoEModel(nn.Module): - def __init__( - self, - config: PhiMoEConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.padding_idx = config.pad_token_id lora_vocab = ((lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) if lora_config else 0) @@ -529,23 +529,15 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP): } embedding_padding_modules = ["lm_head"] - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config self.config = config self.lora_config = lora_config - self.model = PhiMoEModel(config, - cache_config, - quant_config, - lora_config=lora_config) + self.model = PhiMoEModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 93919c9c051c0..6bd5e119dd2dd 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -38,7 +38,7 @@ from vllm.utils import is_list_of from .interfaces import SupportsMultiModal, SupportsPP -from .utils import init_vllm_registered_model +from .utils import init_vllm_registered_model, maybe_prefix try: from xformers import ops as xops @@ -152,11 +152,7 @@ def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs): class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): - def __init__( - self, - 
vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config multimodal_config = vllm_config.model_config.multimodal_config @@ -176,7 +172,7 @@ def __init__( self.language_model = init_vllm_registered_model( config.text_config, vllm_config=vllm_config, - prefix="language_model") + prefix=maybe_prefix(prefix, "language_model")) self.vision_encoder = VisionTransformer(self.vision_args) self.vision_language_adapter = VisionLanguageAdapter( diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index d3f10ee7c85ca..cc70099361dd2 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -50,7 +50,8 @@ from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .utils import (flatten_bn, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) logger = init_logger(__name__) @@ -552,14 +553,13 @@ def forward( @support_torch_compile class QWenModel(nn.Module): - def __init__( - self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config self.vocab_size = config.vocab_size @@ -865,20 +865,17 @@ def dummy_data_for_qwen( class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config self.quant_config = quant_config - self.transformer = QWenModel(config, cache_config, quant_config) + self.transformer = QWenModel(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "transformer")) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index b0156a25ca5cf..2195ce49aa9a7 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -240,14 +240,13 @@ def forward( @support_torch_compile class Qwen2Model(nn.Module): - def __init__( - self, - config: Qwen2Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -403,11 +402,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): "up_proj": ("gate_up_proj", 1), } - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() 
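The multimodal wrappers thread the same two keyword arguments down to their submodules, deriving child prefixes from the parent prefix instead of hard-coding them. A rough sketch of that calling convention follows; MyVLM is an illustrative name only, while the call shapes mirror the llava and qwen2_audio hunks of this patch.

    import torch.nn as nn

    from vllm.config import VllmConfig
    from vllm.model_executor.models.utils import (init_vllm_registered_model,
                                                  maybe_prefix)


    class MyVLM(nn.Module):

        def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
            super().__init__()
            config = vllm_config.model_config.hf_config
            # Child prefixes are built with maybe_prefix so that nesting
            # this wrapper under another prefix keeps checkpoint weight
            # names aligned with the module hierarchy.
            self.language_model = init_vllm_registered_model(
                config.text_config,
                vllm_config=vllm_config,
                prefix=maybe_prefix(prefix, "language_model"))
            # When an inner model class is constructed directly but should
            # see a different HF sub-config, the patch swaps it into a copy
            # of the config, as in the qwen2_audio.py hunk:
            #   Qwen2Model(
            #       vllm_config=vllm_config.with_hf_config(
            #           config.text_config),
            #       prefix=prefix)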
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
@@ -429,9 +424,7 @@ def __init__(
         self.lora_config = lora_config
         self.quant_config = quant_config
 
-        self.model = Qwen2Model(config,
-                                cache_config,
-                                quant_config,
+        self.model = Qwen2Model(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
 
         if config.tie_word_embeddings:
diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py
index 1057720e8c308..d30950361ad89 100644
--- a/vllm/model_executor/models/qwen2_audio.py
+++ b/vllm/model_executor/models/qwen2_audio.py
@@ -264,14 +264,9 @@ def input_mapper_for_qwen2_audio(
 class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
                                          SupportsPP):
 
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         multimodal_config = vllm_config.model_config.multimodal_config
         self.config = config
@@ -283,8 +278,9 @@ def __init__(
 
         self.quant_config = quant_config
-        self.language_model = Qwen2Model(config.text_config, cache_config,
-                                         quant_config)
+        self.language_model = Qwen2Model(
+            vllm_config=vllm_config.with_hf_config(config.text_config),
+            prefix=prefix)
         self.unpadded_vocab_size = config.text_config.vocab_size
         if config.text_config.tie_word_embeddings:
             self.lm_head = self.language_model.embed_tokens
diff --git a/vllm/model_executor/models/qwen2_cls.py b/vllm/model_executor/models/qwen2_cls.py
index 25ecf76e35f22..020af88aadd98 100644
--- a/vllm/model_executor/models/qwen2_cls.py
+++ b/vllm/model_executor/models/qwen2_cls.py
@@ -17,7 +17,7 @@
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
 
-from .utils import AutoWeightsLoader
+from .utils import AutoWeightsLoader, maybe_prefix
 
 
 class Qwen2ForSequenceClassification(nn.Module):
@@ -43,11 +43,7 @@ class Qwen2ForSequenceClassification(nn.Module):
     embedding_modules = {}
     embedding_padding_modules = []
 
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
@@ -70,7 +66,8 @@ def __init__(
         self.lora_config = lora_config
 
         self.quant_config = quant_config
-        self.model = Qwen2Model(config, cache_config, quant_config)
+        self.model = Qwen2Model(vllm_config=vllm_config,
+                                prefix=maybe_prefix(prefix, "model"))
 
         self.score = RowParallelLinear(config.hidden_size,
                                        config.num_labels,
diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py
index b1177f9c59063..51c0cd5664fd2 100644
--- a/vllm/model_executor/models/qwen2_moe.py
+++ b/vllm/model_executor/models/qwen2_moe.py
@@ -54,7 +54,8 @@
 
 from .interfaces import SupportsPP
 from .utils import (is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory, make_layers)
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)
 
 
 class Qwen2MoeMLP(nn.Module):
@@ -315,14 +316,13 @@ def forward(
 @support_torch_compile
 class Qwen2MoeModel(nn.Module):
 
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
@@ -377,18 +377,14 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
 
     fall_back_to_pt_during_load = False
 
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         self.config = config
         self.quant_config = quant_config
-        self.model = Qwen2MoeModel(config, cache_config, quant_config)
+        self.model = Qwen2MoeModel(vllm_config=vllm_config,
+                                   prefix=maybe_prefix(prefix, "model"))
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       quant_config=quant_config)
diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py
index 1f9411241bdd6..89768ec9dff37 100644
--- a/vllm/model_executor/models/qwen2_rm.py
+++ b/vllm/model_executor/models/qwen2_rm.py
@@ -18,7 +18,7 @@
 
 from .interfaces import SupportsPP
 from .qwen2 import Qwen2Model
-from .utils import AutoWeightsLoader
+from .utils import AutoWeightsLoader, maybe_prefix
 
 
 class ReLU(nn.Module):
@@ -55,11 +55,7 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
     embedding_modules = {}
     embedding_padding_modules = []
 
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
@@ -82,7 +78,8 @@ def __init__(
         self.lora_config = lora_config
 
         self.quant_config = quant_config
-        self.model = Qwen2Model(config, cache_config, quant_config)
+        self.model = Qwen2Model(vllm_config=vllm_config,
+                                prefix=maybe_prefix(prefix, "model"))
 
         self.score = nn.Sequential(
             ColumnParallelLinear(config.hidden_size,
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index ab80c1494d067..13109758767df 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -70,7 +70,7 @@
 from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
 from .utils import (PPMissingLayer, get_vit_attn_backend,
                     is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory)
+                    make_empty_intermediate_tensors_factory, maybe_prefix)
 
 logger = init_logger(__name__)
 
@@ -966,11 +966,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
     embedding_modules = {}
     embedding_padding_modules = []
 
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
@@ -986,13 +982,11 @@ def __init__(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
             quant_config=self._maybe_ignore_quant_config(quant_config),
-            prefix="visual",
+            prefix=maybe_prefix(prefix, "visual"),
         )
 
-        self.model = Qwen2Model(config,
-                                cache_config,
-                                quant_config,
-                                prefix="model")
+        self.model = Qwen2Model(vllm_config=vllm_config,
+                                prefix=maybe_prefix(prefix, "model"))
 
         if get_pp_group().is_last_rank:
             if config.tie_word_embeddings:
@@ -1001,7 +995,8 @@ def __init__(
                 self.lm_head = ParallelLMHead(config.vocab_size,
                                               config.hidden_size,
                                               quant_config=quant_config,
-                                              prefix="lm_head")
+                                              prefix=maybe_prefix(
+                                                  prefix, "lm_head"))
             else:
                 self.lm_head = PPMissingLayer()
 
diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py
index ffabac8292dbd..4f03ca501fb68 100644
--- a/vllm/model_executor/models/solar.py
+++ b/vllm/model_executor/models/solar.py
@@ -29,7 +29,7 @@
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, LoRAConfig, VllmConfig
+from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -53,7 +53,8 @@
 
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (PPMissingLayer, is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory, make_layers)
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)
 
 
 class SolarMLP(nn.Module):
@@ -266,15 +267,14 @@ def forward(
 @support_torch_compile
 class SolarModel(nn.Module):
 
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config
+
         self.config = config
         self.padding_idx = config.pad_token_id
         lora_vocab = ((lora_config.lora_extra_vocab_size *
@@ -409,25 +409,17 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         "up_proj": ("gate_up_proj", 1),
     }
 
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
         self.config = config
         self.lora_config = lora_config
 
         self.model = SolarModel(
-            config,
-            cache_config,
-            quant_config,
-            lora_config=lora_config,
-            prefix="model",
+            vllm_config=vllm_config,
+            prefix=maybe_prefix(prefix, "model"),
         )
         if get_pp_group().is_last_rank:
             self.unpadded_vocab_size = config.vocab_size
diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py
index 975d316977c37..1125f9e9f9617 100644
--- a/vllm/model_executor/models/stablelm.py
+++ b/vllm/model_executor/models/stablelm.py
@@ -43,7 +43,8 @@
 
 from .interfaces import SupportsPP
 from .utils import (is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory, make_layers)
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)
 
 
 class StablelmMLP(nn.Module):
@@ -193,12 +194,13 @@ def forward(
 
 class StableLMEpochModel(nn.Module):
 
-    def __init__(self,
-                 config: PretrainedConfig,
-                 cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = '') -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
@@ -245,18 +247,14 @@ def forward(
 
 class StablelmForCausalLM(nn.Module, SupportsPP):
 
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         self.config = config
         self.quant_config = quant_config
-        self.model = StableLMEpochModel(config, cache_config, quant_config)
+        self.model = StableLMEpochModel(vllm_config=vllm_config,
+                                        prefix=maybe_prefix(prefix, "model"))
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       quant_config=quant_config)
diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py
index ae61aa4e248a5..ce7a7957f52c4 100644
--- a/vllm/model_executor/models/starcoder2.py
+++ b/vllm/model_executor/models/starcoder2.py
@@ -43,7 +43,8 @@
 
 from .interfaces import SupportsPP
 from .utils import (is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory, make_layers)
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)
 
 
 class Starcoder2Attention(nn.Module):
@@ -195,12 +196,13 @@ def forward(
 @support_torch_compile
 class Starcoder2Model(nn.Module):
 
-    def __init__(self,
-                 config: Starcoder2Config,
-                 cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+
         self.config = config
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
@@ -245,19 +247,13 @@ def forward(
 
 class Starcoder2ForCausalLM(nn.Module, SupportsPP):
 
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         self.config = config
-        self.model = Starcoder2Model(config,
-                                     cache_config,
-                                     quant_config=quant_config)
+        self.model = Starcoder2Model(vllm_config=vllm_config,
+                                     prefix=maybe_prefix(prefix, "model"))
         self.vocab_size = config.vocab_size
         self.unpadded_vocab_size = config.vocab_size
         if config.tie_word_embeddings:
diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py
index d47f0091e0f9f..9fde22c016de0 100644
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -34,7 +34,7 @@
 
 from .interfaces import SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
-                    init_vllm_registered_model,
+                    init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings_from_map)
 
 _AUDIO_PLACEHOLDER_TOKEN = 128002
@@ -339,11 +339,7 @@ def forward(
 @INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox)
 class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
 
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
         multimodal_config = vllm_config.model_config.multimodal_config
@@ -354,6 +350,8 @@ def __init__(
         self.secondary_weights = []
         self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
         if config.audio_model_id is not None:
+            # this prefix is not for initialization, but for loading weights
+            # note the trailing dot
             self.secondary_weights.append(
                 DefaultModelLoader.Source(
                     model_or_path=config.audio_model_id,
@@ -362,8 +360,12 @@ def __init__(
                 ))
         self.multi_modal_projector = UltravoxProjector(config)
         self.language_model = init_vllm_registered_model(
-            config.text_config, vllm_config, prefix="language_model")
+            config.text_config,
+            vllm_config=vllm_config,
+            prefix=maybe_prefix(prefix, "language_model"))
         if config.text_model_id is not None:
+            # this prefix is not for initialization, but for loading weights
+            # note the trailing dot
             self.secondary_weights.append(
                 DefaultModelLoader.Source(model_or_path=config.text_model_id,
                                           revision=None,
diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py
index 7afb99176077b..153527da20d75 100644
--- a/vllm/model_executor/models/xverse.py
+++ b/vllm/model_executor/models/xverse.py
@@ -46,7 +46,8 @@
 
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory, make_layers)
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)
 
 
 class XverseMLP(nn.Module):
@@ -223,11 +224,7 @@ def forward(
 @support_torch_compile
 class XverseModel(nn.Module):
 
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
@@ -315,15 +312,10 @@ class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     }
     embedding_padding_modules = ["lm_head"]
 
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
 
@@ -331,7 +323,8 @@ def __init__(
         self.lora_config = lora_config
 
         self.quant_config = quant_config
-        self.model = XverseModel(config, cache_config, quant_config)
+        self.model = XverseModel(vllm_config=vllm_config,
+                                 prefix=maybe_prefix(prefix, "model"))
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       quant_config=quant_config)