[6/N] pass whole config to inner model (vllm-project#10205)
Signed-off-by: youkaichao <[email protected]>
youkaichao authored Nov 11, 2024
1 parent f0f2e56 commit f89d18f
Showing 69 changed files with 681 additions and 963 deletions.
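Every hunk below makes the same structural change: inner model classes stop taking a hand-picked list of sub-configs (HF config, cache config, quantization config, ...) and instead receive the whole VllmConfig plus a prefix used for weight naming, unpacking whatever they need themselves. A minimal sketch of the convention follows; the config classes here are simplified stand-ins for the real ones in vllm.config, not vLLM's actual definitions.

```python
from dataclasses import dataclass, field
from typing import Any, Optional

import torch.nn as nn


# Simplified stand-ins for vLLM's config objects; the real VllmConfig in
# vllm.config carries many more fields than shown here.
@dataclass
class ModelConfigSketch:
    hf_config: Any = None


@dataclass
class VllmConfigSketch:
    model_config: ModelConfigSketch = field(default_factory=ModelConfigSketch)
    cache_config: Optional[Any] = None
    quant_config: Optional[Any] = None
    lora_config: Optional[Any] = None


class InnerModelSketch(nn.Module):
    # The convention adopted in this commit: keyword-only vllm_config + prefix.
    def __init__(self, *, vllm_config: VllmConfigSketch, prefix: str = "") -> None:
        super().__init__()
        # Each module unpacks only the sub-configs it actually needs.
        self.config = vllm_config.model_config.hf_config
        self.cache_config = vllm_config.cache_config
        self.quant_config = vllm_config.quant_config
        self.prefix = prefix
```

The same shape repeats in every file touched by the commit, which is what lets callers construct any of these modules identically.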
23 changes: 10 additions & 13 deletions vllm/model_executor/models/arctic.py
@@ -34,7 +34,8 @@

 from .interfaces import SupportsPP
 from .utils import (is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory, make_layers)
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)

 logger = init_logger(__name__)

@@ -364,14 +365,13 @@ def forward(
 @support_torch_compile
 class ArcticModel(nn.Module):

-    def __init__(
-        self,
-        config: ArcticConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
@@ -418,13 +418,10 @@ class ArcticForCausalLM(nn.Module, SupportsPP):
     def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         self.config = config
-        self.model = ArcticModel(config,
-                                 cache_config,
-                                 quant_config,
-                                 prefix=prefix)
+        self.model = ArcticModel(vllm_config=vllm_config,
+                                 prefix=maybe_prefix(prefix, "model"))
         self.vocab_size = config.vocab_size
         self.lm_head = ParallelLMHead(
             self.vocab_size,
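The maybe_prefix helper newly imported in arctic.py (and the files below) composes a parent prefix with a submodule name so that nested parameter names line up with checkpoint keys. A small sketch of what such a helper does, inferred from its call sites; the authoritative version lives in vllm/model_executor/models/utils.py and may differ in detail.

```python
def maybe_prefix(prefix: str, name: str) -> str:
    """Prepend `prefix` to `name`, unless the prefix is empty."""
    return name if not prefix else f"{prefix}.{name}"


# A top-level model passes an empty prefix, so the submodule keeps its name:
assert maybe_prefix("", "model") == "model"
# When the model is nested (e.g. as a language backbone), the names compose:
assert maybe_prefix("language_model", "model") == "language_model.model"
```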
48 changes: 26 additions & 22 deletions vllm/model_executor/models/baichuan.py
@@ -253,13 +253,18 @@ def forward(
 @support_torch_compile
 class BaiChuanModel(nn.Module):

-    def __init__(self,
-                 config: PretrainedConfig,
-                 position_embedding: str,
-                 cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        position_embedding: str = "ROPE",
+    ) -> None:
         super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+
         self.config = config
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
@@ -332,21 +337,22 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):

     def __init__(
         self,
+        *,
         vllm_config: VllmConfig,
         prefix: str = "",
         position_embedding: str = "ROPE",
     ):
         super().__init__()
         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
         self.config = config
         self.lora_config = lora_config

         self.quant_config = quant_config
-        self.model = BaiChuanModel(config, position_embedding, cache_config,
-                                   quant_config)
+        self.model = BaiChuanModel(vllm_config=vllm_config,
+                                   prefix=prefix,
+                                   position_embedding=position_embedding)
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       quant_config=quant_config)
@@ -438,26 +444,24 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
     NOTE: the class name has a lower case 'c'.
     """

-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        prefix: str = "",
-    ):
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         if config.hidden_size == 4096:  # baichuan2 7b
-            super().__init__(vllm_config, prefix, "ROPE")
+            super().__init__(vllm_config=vllm_config,
+                             prefix=prefix,
+                             position_embedding="ROPE")
         else:  # baichuan 13b, baichuan2 13b
-            super().__init__(vllm_config, prefix, "ALIBI")
+            super().__init__(vllm_config=vllm_config,
+                             prefix=prefix,
+                             position_embedding="ALIBI")


 class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
     """Baichuan 7B.
     NOTE: the class name has an upper case 'C'.
     """

-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        prefix: str = "",
-    ):
-        super().__init__(vllm_config, prefix, "ROPE")
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config,
+                         prefix=prefix,
+                         position_embedding="ROPE")
23 changes: 11 additions & 12 deletions vllm/model_executor/models/bart.py
@@ -41,6 +41,8 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

+from .utils import maybe_prefix
+
 logger = logging.get_logger(__name__)

@@ -739,13 +741,14 @@ class BartModel(nn.Module):
         "encoder.embed_tokens.weight", "decoder.embed_tokens.weight"
     ]

-    def __init__(self,
-                 config: BartConfig,
-                 cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 lora_config: Optional[LoRAConfig] = None):
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()

+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config
+
         self.config = config

         self.padding_idx = config.pad_token_id
@@ -810,20 +813,16 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
 class BartForConditionalGeneration(nn.Module):
     base_model_prefix = "model"

-    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

         super().__init__()
         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
-        quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
         # currently all existing BART models have `tie_word_embeddings` enabled
         assert config.tie_word_embeddings
         self.config = config
-        self.model = BartModel(config,
-                               cache_config,
-                               quant_config,
-                               lora_config=lora_config)
+        self.model = BartModel(vllm_config=vllm_config,
+                               prefix=maybe_prefix(prefix, "model"))

         self.unpadded_vocab_size = config.vocab_size
         if lora_config:
25 changes: 11 additions & 14 deletions vllm/model_executor/models/bert.py
@@ -21,6 +21,8 @@
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput

+from .utils import maybe_prefix
+

 class BertEmbedding(nn.Module):

@@ -309,12 +311,13 @@ def forward(self, hidden_states: torch.Tensor,

 class BertModel(nn.Module):

-    def __init__(self,
-                 config: BertConfig,
-                 cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+
         self.embeddings = BertEmbedding(config)
         self.encoder = BertEncoder(config,
                                    cache_config,
@@ -382,17 +385,11 @@ class BertEmbeddingModel(nn.Module):
         _pooler: An instance of Pooler used for pooling operations.
     """

-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
-        config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
-        quant_config = vllm_config.quant_config
         pooler_config = vllm_config.model_config.pooler_config
-        self.model = BertModel(config, cache_config, quant_config)
+        self.model = BertModel(vllm_config=vllm_config,
+                               prefix=maybe_prefix(prefix, "model"))
         self._pooler = Pooler.from_config_with_defaults(
             pooler_config,
             pooling_type=PoolingType.CLS,
10 changes: 3 additions & 7 deletions vllm/model_executor/models/blip2.py
@@ -23,7 +23,7 @@
                    get_max_blip_image_tokens)
 from .interfaces import SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, init_vllm_registered_model,
-                    merge_multimodal_embeddings)
+                    maybe_prefix, merge_multimodal_embeddings)

 # We use this internally as placeholders since there is no image token
 # defined on the HuggingFace repo
@@ -483,11 +483,7 @@ def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs):
 @INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
 class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):

-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -517,7 +513,7 @@ def __init__(
         self.language_model = init_vllm_registered_model(
             config.text_config,
             vllm_config=vllm_config,
-            prefix="language_model")
+            prefix=maybe_prefix(prefix, "language_model"))

         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
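The Blip2 change is about prefix composition rather than config plumbing: the text backbone's prefix is now derived from the wrapper's own prefix instead of being hard-coded to "language_model". A toy illustration of why that matters, reusing the maybe_prefix behaviour sketched earlier; the nesting scenario and weight names are hypothetical.

```python
def maybe_prefix(prefix: str, name: str) -> str:
    return name if not prefix else f"{prefix}.{name}"


# Blip2 instantiated as a top-level model: same result as the old hard-coded value.
print(maybe_prefix("", "language_model"))        # language_model

# Hypothetical case where the Blip2 wrapper is itself nested under a prefix:
print(maybe_prefix("outer", "language_model"))   # outer.language_model
# A hard-coded prefix="language_model" would produce mismatched weight names
# in the nested case, which is what the maybe_prefix call avoids.
```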
21 changes: 11 additions & 10 deletions vllm/model_executor/models/bloom.py
@@ -42,7 +42,8 @@

 from .interfaces import SupportsPP
 from .utils import (is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory, make_layers)
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)


 def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
@@ -221,14 +222,13 @@ def forward(
 @support_torch_compile
 class BloomModel(nn.Module):

-    def __init__(
-        self,
-        config: BloomConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ):
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+
         self.embed_dim = config.hidden_size

         # Embedding + LN Embedding
@@ -288,11 +288,12 @@ def __init__(
     ):
         super().__init__()
         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         self.config = config
         self.quant_config = quant_config
-        self.transformer = BloomModel(config, cache_config, quant_config)
+        self.transformer = BloomModel(vllm_config=vllm_config,
+                                      prefix=maybe_prefix(
+                                          prefix, "transformer"))
         if self.config.tie_word_embeddings:
             self.lm_head = self.transformer.word_embeddings
         else:
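The unchanged context lines at the end of the Bloom hunk show the weight-tying branch: when the HF config sets tie_word_embeddings, the LM head simply reuses the transformer's input embedding instead of allocating a separate projection. A generic, self-contained sketch of that pattern (not Bloom's actual head implementation):

```python
import torch
import torch.nn as nn


class TiedHeadSketch(nn.Module):
    """Illustration of the tie_word_embeddings branch, reduced to its essence."""

    def __init__(self, vocab_size: int, hidden_size: int,
                 tie_word_embeddings: bool = True) -> None:
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        if tie_word_embeddings:
            # Reuse the input embedding matrix as the output projection weight.
            self.lm_head_weight = self.word_embeddings.weight
        else:
            self.lm_head_weight = nn.Parameter(
                torch.empty(vocab_size, hidden_size))

    def logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states @ self.lm_head_weight.t()
```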
27 changes: 11 additions & 16 deletions vllm/model_executor/models/chameleon.py
@@ -37,7 +37,8 @@

 from .interfaces import SupportsMultiModal, SupportsPP
 from .utils import (is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory, make_layers)
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)

 # These configs are not part of the model config but the preprocessor
 # and processor files, so we hardcode them in the model file for now.
@@ -831,14 +832,13 @@ def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:

 class ChameleonModel(nn.Module):

-    def __init__(
-        self,
-        config: ChameleonConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+
         self.config = config
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
@@ -924,19 +924,14 @@ def forward(
 class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
                                         SupportsPP):

-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
-        quant_config = vllm_config.quant_config
         multimodal_config = vllm_config.model_config.multimodal_config
         self.config = config
         self.multimodal_config = multimodal_config
-        self.model = ChameleonModel(config, cache_config, quant_config)
+        self.model = ChameleonModel(vllm_config=vllm_config,
+                                    prefix=maybe_prefix(prefix, "model"))
         self.unpadded_vocab_size = config.vocab_size
         self.lm_head = ParallelLMHead(
             self.unpadded_vocab_size,
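Taken together with the earlier commits in this series, the payoff of the change is a single constructor shape shared by outer and inner models alike, so a caller no longer needs per-architecture plumbing of cache, quantization, or LoRA configs. A sketch of the uniform call site; this is an illustration, not vLLM's actual model-loader code.

```python
from typing import Type

import torch.nn as nn


def build_model(model_cls: Type[nn.Module], vllm_config) -> nn.Module:
    # Any model class following the new convention can be constructed the same
    # way, regardless of which sub-configs it uses internally.
    return model_cls(vllm_config=vllm_config, prefix="")
```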
[Diffs for the remaining 62 changed files are not shown here.]
