Skip to content

Commit

Permalink
[VLM] Enable overriding whether post layernorm is used in vision enco…
Browse files Browse the repository at this point in the history
…der + fix quant args (vllm-project#9217)

Co-authored-by: Isotr0py <[email protected]>
Signed-off-by: Maxime Fournioux <[email protected]>
  • Loading branch information
2 people authored and mfournioux committed Nov 20, 2024
1 parent 0cf432f commit 0dbeff9
Show file tree
Hide file tree
Showing 18 changed files with 551 additions and 253 deletions.
20 changes: 16 additions & 4 deletions vllm/model_executor/layers/quantization/awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
Expand All @@ -21,10 +22,12 @@ def __init__(
weight_bits: int,
group_size: int,
zero_point: bool,
modules_to_not_convert: Optional[List[str]] = None,
) -> None:
self.weight_bits = weight_bits
self.group_size = group_size
self.zero_point = zero_point
self.modules_to_not_convert = modules_to_not_convert or []

if self.weight_bits != 4:
raise ValueError(
Expand All @@ -35,7 +38,8 @@ def __init__(
def __repr__(self) -> str:
return (f"AWQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"zero_point={self.zero_point})")
f"zero_point={self.zero_point}, "
f"modules_to_not_convert={self.modules_to_not_convert})")

def get_name(self) -> str:
return "awq"
Expand All @@ -61,18 +65,26 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
zero_point = cls.get_from_keys(config, ["zero_point"])
return cls(weight_bits, group_size, zero_point)
modules_to_not_convert = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None)
return cls(weight_bits, group_size, zero_point, modules_to_not_convert)

def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["AWQLinearMethod"]:
prefix: str) -> Optional["LinearMethodBase"]:
if isinstance(layer, LinearBase):
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
return UnquantizedLinearMethod()
return AWQLinearMethod(self)
return None

def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]


def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
return any(module_name in prefix for module_name in modules_to_not_convert)


class AWQLinearMethod(LinearMethodBase):
"""Linear method for AWQ.
Expand Down
87 changes: 59 additions & 28 deletions vllm/model_executor/models/blip.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def input_processor_for_blip(
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
class BlipVisionEmbeddings(nn.Module):

def __init__(self, config: BlipVisionConfig):
def __init__(self, config: Union[BlipVisionConfig, Blip2VisionConfig]):
super().__init__()

self.config = config
Expand Down Expand Up @@ -167,9 +167,10 @@ class BlipParallelAttention(nn.Module):

def __init__(
self,
config: BlipVisionConfig,
config: Union[BlipVisionConfig, Blip2VisionConfig],
quant_config: Optional[QuantizationConfig] = None,
):
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
Expand All @@ -189,11 +190,13 @@ def __init__(
self.num_heads,
bias=config.qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv",
)
self.projection = RowParallelLinear(
self.embed_dim,
self.embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.projection",
)

self.tp_size = get_tensor_model_parallel_world_size()
Expand Down Expand Up @@ -235,9 +238,12 @@ def forward(

class BlipMLP(nn.Module):

def __init__(self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None):
def __init__(
self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()

self.config = config
Expand All @@ -246,11 +252,13 @@ def __init__(self,
self.fc1 = ColumnParallelLinear(config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.fc1")
self.fc2 = RowParallelLinear(config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.fc2")

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states)
Expand All @@ -262,24 +270,32 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

class BlipEncoderLayer(nn.Module):

def __init__(self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None):
def __init__(
self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()

# fallback to sdpa attention if tp unavailable
num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = BlipParallelAttention(config,
quant_config=quant_config)
self.self_attn = BlipParallelAttention(
config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
else:
# Blip doesn't have SDPA attention implemented in transformers
# use eager attention instead for cpu backend
self.self_attn = BlipAttention(config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.mlp = BlipMLP(config, quant_config=quant_config)
self.mlp = BlipMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.layer_norm2 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)

Expand Down Expand Up @@ -307,10 +323,13 @@ class BlipEncoder(nn.Module):
config: BlipConfig
"""

def __init__(self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None):
def __init__(
self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None,
prefix: str = "",
) -> None:
super().__init__()

self.config = config
Expand All @@ -321,8 +340,10 @@ def __init__(self,
num_hidden_layers = num_hidden_layers_override

self.layers = nn.ModuleList([
BlipEncoderLayer(config=config, quant_config=quant_config)
for _ in range(num_hidden_layers)
BlipEncoderLayer(config=config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(num_hidden_layers)
])

def forward(self, inputs_embeds: torch.Tensor):
Expand All @@ -337,10 +358,15 @@ class BlipVisionModel(nn.Module):
config_class = BlipVisionConfig
main_input_name = "pixel_values"

def __init__(self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None):
def __init__(
self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
require_post_norm: Optional[bool] = None,
prefix: str = "",
) -> None:
super().__init__()

tp_size = get_tensor_model_parallel_world_size()
Expand All @@ -354,19 +380,24 @@ def __init__(self,
config=config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder",
)

num_hidden_layers = config.num_hidden_layers
if len(self.encoder.layers) > config.num_hidden_layers:
raise ValueError(
f"The original encoder only has {config.num_hidden_layers} "
f"The original encoder only has {num_hidden_layers} "
f"layers, but you requested {len(self.encoder.layers)} layers."
)
elif len(self.encoder.layers) == config.num_hidden_layers:

# If possible, skip post_layernorm to conserve memory
if require_post_norm is None:
require_post_norm = len(self.encoder.layers) == num_hidden_layers

if require_post_norm:
self.post_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
else:
# post_layernorm is unused when we extract intermediate features
# In this case, we can skip it to conserve memory
self.post_layernorm = None

def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/models/blip2.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ def __init__(self,
self.multimodal_config = multimodal_config

# TODO: Optionally initializes this for supporting embeddings.
self.vision_model = BlipVisionModel(config.vision_config)
self.vision_model = BlipVisionModel(config.vision_config, quant_config)

self.query_tokens = nn.Parameter(
torch.zeros(1, config.num_query_tokens,
Expand Down
Loading

0 comments on commit 0dbeff9

Please sign in to comment.