Fix prefix strings for quantized VLMs
mgoin committed Oct 28, 2024
1 parent 2adb440 commit 259bc5b
Showing 14 changed files with 155 additions and 49 deletions.
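
Every hunk below follows the same pattern: a constructor gains a prefix parameter, appends its own attribute name (e.g. f"{prefix}.qkv_proj"), and hands the result to its children, so each quantizable layer ends up knowing its fully qualified module name. Per-layer quantization logic typically keys on that name when choosing a scheme or skipping a layer. A minimal sketch of the idea follows (hypothetical classes; only maybe_prefix mirrors the utility imported in gemma.py, and its behaviour here is assumed from how the diff uses it):

# Sketch only: how the prefix strings added in this commit compose into
# fully qualified module names.

def maybe_prefix(prefix: str, name: str) -> str:
    # Assumed behaviour: join child onto parent, but avoid a leading "."
    # when the model is built standalone (prefix="").
    return f"{prefix}.{name}" if prefix else name


class AttentionSketch:
    """Hypothetical stand-in for GemmaAttention: children append their own name."""

    def __init__(self, prefix: str = "") -> None:
        self.qkv_proj_name = f"{prefix}.qkv_proj"
        self.o_proj_name = f"{prefix}.o_proj"


attn = AttentionSketch(
    prefix=maybe_prefix("language_model.model", "layers.0.self_attn"))
print(attn.qkv_proj_name)  # language_model.model.layers.0.self_attn.qkv_proj
print(attn.o_proj_name)    # language_model.model.layers.0.self_attn.o_proj
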
5 changes: 4 additions & 1 deletion vllm/model_executor/models/blip2.py
@@ -507,7 +507,10 @@ def __init__(self,
         )
 
         self.language_model = init_vllm_registered_model(
-            config.text_config, cache_config, quant_config)
+            config.text_config,
+            cache_config,
+            quant_config,
+            prefix="language_model")
 
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
58 changes: 39 additions & 19 deletions vllm/model_executor/models/gemma.py
@@ -43,7 +43,8 @@
 
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory, make_layers)
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)
 
 logger = init_logger(__name__)
 
@@ -83,16 +84,23 @@ def __init__(
         hidden_act: Optional[str] = None,
         hidden_activation: Optional[str] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
+            hidden_size,
+            [intermediate_size] * 2,
             bias=False,
-            quant_config=quant_config)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config)
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
+        )
         self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation)
 
     def forward(self, x):
@@ -104,15 +112,18 @@ def forward(self, x):
 
 class GemmaAttention(nn.Module):
 
-    def __init__(self,
-                 hidden_size: int,
-                 num_heads: int,
-                 num_kv_heads: int,
-                 head_dim: int,
-                 max_position_embeddings: int = 8192,
-                 rope_theta: float = 10000,
-                 cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        head_dim: int,
+        max_position_embeddings: int = 8192,
+        rope_theta: float = 10000,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -142,12 +153,14 @@ def __init__(self,
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
 
         self.rotary_emb = get_rope(
@@ -186,6 +199,7 @@ def __init__(
         config: GemmaConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -198,13 +212,15 @@ def __init__(
             rope_theta=config.rope_theta,
             cache_config=cache_config,
             quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
         )
         self.mlp = GemmaMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             hidden_activation=getattr(config, "hidden_activation", None),
             quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
         )
         self.input_layernorm = GemmaRMSNorm(config.hidden_size,
                                             eps=config.rms_norm_eps)
@@ -259,8 +275,8 @@ def __init__(
         )
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: GemmaDecoderLayer(config, cache_config, quant_config
-                                             ),
+            lambda prefix: GemmaDecoderLayer(
+                config, cache_config, quant_config, prefix=prefix),
             prefix=f"{prefix}.layers")
         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
@@ -366,6 +382,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
 
@@ -375,7 +392,10 @@ def __init__(
         self.lora_config = lora_config
 
         self.quant_config = quant_config
-        self.model = GemmaModel(config, cache_config, quant_config)
+        self.model = GemmaModel(config,
+                                cache_config,
+                                quant_config,
+                                prefix=maybe_prefix(prefix, "model"))
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
         self.make_empty_intermediate_tensors = (
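
The gemma.py changes show the full chain: GemmaForCausalLM wraps its inner GemmaModel with maybe_prefix(prefix, "model"), make_layers builds per-layer prefixes under f"{prefix}.layers", and each decoder layer appends ".self_attn" or ".mlp" before its projections add their own names. The snippet below is a sketch (it assumes make_layers appends the layer index to the "layers" prefix it receives) that prints the resulting names for the first decoder layer, both for a standalone Gemma and for the same class built inside a VLM wrapper that passes prefix="language_model":

# Sketch of the module names produced by the prefix chain described above.
for top in ("", "language_model"):               # standalone vs. inside a VLM
    base = f"{top}.model" if top else "model"    # maybe_prefix(prefix, "model")
    layer = f"{base}.layers.0"                   # make_layers (assumed ".{i}")
    for proj in ("self_attn.qkv_proj", "self_attn.o_proj",
                 "mlp.gate_up_proj", "mlp.down_proj"):
        print(f"{layer}.{proj}")
# e.g. model.layers.0.self_attn.qkv_proj
#      language_model.model.layers.0.mlp.down_proj
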
5 changes: 4 additions & 1 deletion vllm/model_executor/models/internvl.py
@@ -439,7 +439,10 @@ def __init__(self,
         )
 
         self.language_model = init_vllm_registered_model(
-            config.text_config, cache_config, quant_config)
+            config.text_config,
+            cache_config,
+            quant_config,
+            prefix="language_model")
 
         self.mlp1 = self._init_mlp1(config)
 
20 changes: 15 additions & 5 deletions vllm/model_executor/models/llava.py
@@ -210,6 +210,7 @@ def init_vision_tower_for_llava(
     quant_config: Optional[QuantizationConfig],
     *,
     require_post_norm: Optional[bool] = None,
+    prefix: str = "",
 ):
     vision_config = hf_config.vision_config
 
@@ -224,23 +225,26 @@
     if isinstance(vision_config, CLIPVisionConfig):
         return CLIPVisionModel(
             vision_config,
-            quant_config,
+            quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers,
             require_post_norm=require_post_norm,
+            prefix=prefix,
         )
     elif isinstance(vision_config, SiglipVisionConfig):
         return SiglipVisionModel(
             vision_config,
-            quant_config,
+            quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers,
             require_post_norm=require_post_norm,
+            prefix=prefix,
         )
     elif isinstance(vision_config, PixtralVisionConfig):
         return PixtralHFVisionModel(
             vision_config,
-            quant_config,
+            quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers,
             require_post_norm=require_post_norm,
+            prefix=prefix,
         )
 
     msg = f"Unsupported vision config: {type(vision_config)}"
@@ -274,14 +278,20 @@ def __init__(self,
 
         # TODO: Optionally initializes this for supporting embeddings.
         self.vision_tower = init_vision_tower_for_llava(
-            config, quant_config, require_post_norm=False)
+            config,
+            quant_config,
+            require_post_norm=False,
+            prefix="vision_tower")
         self.multi_modal_projector = LlavaMultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,
             text_hidden_size=config.text_config.hidden_size,
             projector_hidden_act=config.projector_hidden_act)
 
         self.language_model = init_vllm_registered_model(
-            config.text_config, cache_config, quant_config)
+            config.text_config,
+            cache_config,
+            quant_config,
+            prefix="language_model")
 
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
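
In the LLaVA-family models, the two towers now get distinct top-level prefixes: "vision_tower" for the vision encoder and "language_model" for the text backbone. Quantized VLM checkpoints commonly quantize only the language model and list the modules to skip by name, and a name-based filter like that only works once the prefixes are correct. The sketch below is illustrative only; the patterns are hypothetical, not taken from any particular checkpoint or from vLLM's quantization code:

# Sketch of a name-based skip list in the style a quantized VLM checkpoint
# might declare (hypothetical patterns).
import re

IGNORE_PATTERNS = [r"^vision_tower\.", r"\.lm_head$"]

def is_quantized(module_name: str) -> bool:
    # Quantize everything except modules matching an ignore pattern.
    return not any(re.search(p, module_name) for p in IGNORE_PATTERNS)

print(is_quantized("vision_tower.vision_model.encoder.layers.0.self_attn.q_proj"))  # False
print(is_quantized("language_model.model.layers.0.self_attn.qkv_proj"))             # True
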
10 changes: 8 additions & 2 deletions vllm/model_executor/models/llava_next.py
@@ -293,7 +293,10 @@ def __init__(self,
 
         # TODO: Optionally initializes this for supporting embeddings.
         self.vision_tower = init_vision_tower_for_llava(
-            config, quant_config, require_post_norm=False)
+            config,
+            quant_config,
+            require_post_norm=False,
+            prefix="vision_tower")
         self.image_newline = nn.Parameter(
             torch.empty(config.text_config.hidden_size))
         self.multi_modal_projector = LlavaMultiModalProjector(
@@ -302,7 +305,10 @@ def __init__(self,
             projector_hidden_act=config.projector_hidden_act)
 
         self.language_model = init_vllm_registered_model(
-            config.text_config, cache_config, quant_config)
+            config.text_config,
+            cache_config,
+            quant_config,
+            prefix="language_model")
 
         # The same model class supports both language generation and embedding
         # because the architecture name is the same
10 changes: 8 additions & 2 deletions vllm/model_executor/models/llava_next_video.py
@@ -257,14 +257,20 @@ def __init__(self,
 
         # Initialize the vision tower only up to the required feature layer
         self.vision_tower = init_vision_tower_for_llava(
-            config, quant_config, require_post_norm=False)
+            config,
+            quant_config,
+            require_post_norm=False,
+            prefix="vision_tower")
         self.vision_resampler = LlavaNextVideoPooler(config)
         self.multi_modal_projector = LlavaNextMultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,
             text_hidden_size=config.text_config.hidden_size,
             projector_hidden_act=config.projector_hidden_act)
         self.language_model = init_vllm_registered_model(
-            config.text_config, cache_config, quant_config)
+            config.text_config,
+            cache_config,
+            quant_config,
+            prefix="language_model")
 
         self.make_empty_intermediate_tensors = (
             self.language_model.model.make_empty_intermediate_tensors)
10 changes: 8 additions & 2 deletions vllm/model_executor/models/llava_onevision.py
@@ -401,10 +401,16 @@ def __init__(self,
 
         # Initialize the vision tower only up to the required feature layer
        self.vision_tower = init_vision_tower_for_llava(
-            config, quant_config, require_post_norm=False)
+            config,
+            quant_config,
+            require_post_norm=False,
+            prefix="vision_tower")
         self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
         self.language_model = init_vllm_registered_model(
-            config.text_config, cache_config, quant_config)
+            config.text_config,
+            cache_config,
+            quant_config,
+            prefix="language_model")
         self.image_newline = nn.Parameter(
             torch.empty(config.text_config.hidden_size))
 
