[Model] Add BNB quantization support for Mllama #9720

Merged: 10 commits, Oct 29, 2024
Changes from 8 commits
35 changes: 31 additions & 4 deletions vllm/model_executor/layers/quantization/bitsandbytes.py
@@ -3,6 +3,7 @@
 import torch
 
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               UnquantizedLinearMethod,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
@@ -23,7 +24,7 @@ def __init__(
         bnb_4bit_use_double_quant: bool = False,
         llm_int8_enable_fp32_cpu_offload: bool = False,
         llm_int8_has_fp16_weight: bool = False,
-        llm_int8_skip_modules: Optional[Any] = None,
+        llm_int8_skip_modules: Optional[List[str]] = None,
         llm_int8_threshold: float = 0.0,
     ) -> None:
 
@@ -34,11 +35,15 @@
         self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
         self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
         self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
-        self.llm_int8_skip_modules = llm_int8_skip_modules
+        self.llm_int8_skip_modules = llm_int8_skip_modules or []
         self.llm_int8_threshold = llm_int8_threshold
 
     def __repr__(self) -> str:
-        return "BitsAndBytesConfig"
+        return (f"BitsAndBytesConfig(load_in_8bit={self.load_in_8bit}, "
+                f"load_in_4bit={self.load_in_4bit}, "
+                f"bnb_4bit_compute_dtype={self.bnb_4bit_compute_dtype}, "
+                f"bnb_4bit_quant_type={self.bnb_4bit_quant_type}, "
+                f"llm_int8_skip_modules={self.llm_int8_skip_modules})")
 
     @classmethod
     def get_name(self) -> str:
@@ -102,15 +107,21 @@ def get_safe_value(config, keys, default_value=None):
                    llm_int8_threshold=llm_int8_threshold)
 
     def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["BitsAndBytesLinearMethod"]:
+                         prefix: str) -> Optional["LinearMethodBase"]:
         if isinstance(layer, LinearBase):
+            if is_layer_skipped_bnb(prefix, self.llm_int8_skip_modules):
+                return UnquantizedLinearMethod()
             return BitsAndBytesLinearMethod(self)
         return None
 
     def get_scaled_act_names(self) -> List[str]:
         return []
 
 
+def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]):
+    return any(module_name in prefix for module_name in llm_int8_skip_modules)
+
+
 class BitsAndBytesLinearMethod(LinearMethodBase):
     """Linear method for BitsAndBytes.
@@ -211,6 +222,11 @@ def _apply_8bit_weight(
         from bitsandbytes import MatmulLtState, matmul
 
         original_type = x.dtype
+        original_shape = x.shape
+        reshape_after_matmul = False
+        if x.ndim > 2:
+            x = x.reshape(-1, x.size(-1))
+            reshape_after_matmul = True
         bf_x = x.to(torch.bfloat16)
 
         qweight = layer.qweight
@@ -265,6 +281,9 @@ def _apply_8bit_weight(
 
         out = out.to(original_type)
 
+        if reshape_after_matmul:
+            out = out.view(*original_shape[:-1], out.size(-1))
+
         if bias is not None:
             out += bias
 
@@ -282,6 +301,11 @@ def _apply_4bit_weight(
         from bitsandbytes import matmul_4bit
 
         original_type = x.dtype
+        original_shape = x.shape
+        reshape_after_matmul = False
+        if x.ndim > 2:
+            x = x.reshape(-1, x.size(-1))
+            reshape_after_matmul = True
         bf_x = x.to(torch.bfloat16)
 
         qweight = layer.qweight
@@ -310,6 +334,9 @@ def _apply_4bit_weight(
 
         out = out.to(original_type)
 
+        if reshape_after_matmul:
+            out = out.view(*original_shape[:-1], out.size(-1))
+
         if bias is not None:
             out += bias
 
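Note on the reshape changes above: _apply_8bit_weight and _apply_4bit_weight previously assumed 2-D activations, but inputs such as Mllama's vision activations can carry extra leading dimensions, so the input is now flattened before the bitsandbytes matmul and the leading dimensions are restored afterwards. A minimal, self-contained sketch of that round-trip; the function name and shapes are illustrative, and a plain torch matmul stands in for the quantized kernel:

import torch


def matmul_flatten_leading_dims(x: torch.Tensor,
                                weight: torch.Tensor) -> torch.Tensor:
    # Stand-in for a kernel that only accepts 2-D input (as the bnb matmul does).
    original_shape = x.shape
    reshape_after_matmul = False
    if x.ndim > 2:
        # Collapse all leading dims into one batch dim: (..., in) -> (N, in).
        x = x.reshape(-1, x.size(-1))
        reshape_after_matmul = True

    out = x @ weight.t()  # (N, out_features)

    if reshape_after_matmul:
        # Restore the original leading dims: (N, out) -> (..., out).
        out = out.view(*original_shape[:-1], out.size(-1))
    return out


# e.g. a (batch, num_tiles, num_tokens, hidden) vision activation
x = torch.randn(2, 4, 16, 32)
w = torch.randn(64, 32)
assert matmul_flatten_leading_dims(x, w).shape == (2, 4, 16, 64)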
19 changes: 16 additions & 3 deletions vllm/model_executor/model_loader/loader.py
@@ -899,6 +899,19 @@ def _get_quantized_weights_iterator(
         return self._unquantized_generator(hf_weights_files, use_safetensors,
                                            quant_state_dict), quant_state_dict
 
+    def _is_8bit_weight_name(self, weight_name: str):
+        quantized_suffix = {".scb", ".weight_format"}
+        return any(weight_name.lower().endswith(suffix)
+                   for suffix in quantized_suffix)
+
+    def _is_4bit_weight_name(self, weight_name: str):
+        quantized_suffix = {
+            "absmax", "quant_map", "nested_absmax", "nested_quant_map",
+            "bitsandbytes"
+        }
+        suffix = weight_name.split(".")[-1]
+        return any(q_suffix in suffix for q_suffix in quantized_suffix)
+
     def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
                                   quant_state_dict) -> Generator:
         for weight_name, weight_tensor in self._hf_weight_iter(
@@ -912,7 +925,7 @@ def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
         for weight_name, weight_tensor in self._hf_weight_iter(
                 hf_weights_files, use_safetensors):
 
-            if not weight_name.endswith((".weight", ".bias")):
+            if self._is_8bit_weight_name(weight_name):
                 continue
 
             qweight_name = weight_name.replace(".weight", ".qweight")
@@ -932,7 +945,7 @@ def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
                                              use_safetensors)
         temp_state_dict = {}
         for weight_name, weight_tensor in weight_iterator:
-            if weight_name.endswith((".weight", ".bias")):
+            if not self._is_4bit_weight_name(weight_name):
                 continue
             # bitsandbytes library requires
             # weight.quant_state.bitsandbytes__* in CPU
@@ -956,7 +969,7 @@ def _parse_quant_state(param_name: str,
         for weight_name, weight_tensor in self._hf_weight_iter(
                 hf_weights_files, use_safetensors):
 
-            if not weight_name.endswith((".weight", ".bias")):
+            if self._is_4bit_weight_name(weight_name):
                 continue
 
             if (f"{weight_name}.quant_state.bitsandbytes__nf4" \
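Note on the loader changes above: the old checks assumed every tensor whose name does not end in .weight or .bias is bitsandbytes metadata; the new _is_8bit_weight_name / _is_4bit_weight_name helpers classify tensors by explicit bitsandbytes suffixes instead, which is more robust for models like Mllama with differently named parameters. A small illustration of how the 4-bit classifier sorts tensor names; the example names below are hypothetical, not taken from a real checkpoint:

def is_4bit_quant_state_name(weight_name: str) -> bool:
    # Mirrors the logic of the new _is_4bit_weight_name helper:
    # inspect only the last dotted component of the tensor name.
    quantized_suffix = {
        "absmax", "quant_map", "nested_absmax", "nested_quant_map",
        "bitsandbytes"
    }
    suffix = weight_name.split(".")[-1]
    return any(q_suffix in suffix for q_suffix in quantized_suffix)


# Hypothetical tensor names from a pre-quantized 4-bit checkpoint:
for name in [
        "model.layers.0.self_attn.q_proj.weight",           # real weight -> False
        "model.layers.0.self_attn.q_proj.weight.absmax",     # quant state -> True
        "model.layers.0.self_attn.q_proj.weight.quant_map",  # quant state -> True
        "multi_modal_projector.bias",                        # plain bias  -> False
]:
    print(f"{name}: {is_4bit_quant_state_name(name)}")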
42 changes: 37 additions & 5 deletions vllm/model_executor/models/mllama.py
@@ -325,7 +325,10 @@ def forward(self, hidden_state: torch.Tensor,
 # TODO: support other attention backends for attention in vision model
 class MllamaVisionSdpaAttention(nn.Module):
 
-    def __init__(self, config: config_mllama.MllamaVisionConfig):
+    def __init__(self,
+                 config: config_mllama.MllamaVisionConfig,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
         super().__init__()
 
         model_parallel_size = get_tensor_model_parallel_world_size()
@@ -341,12 +344,16 @@ def __init__(self, config: config_mllama.MllamaVisionConfig):
             self.head_dim,
             self.num_heads,
             bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.num_heads * self.head_dim,
             self.embed_dim,
             bias=False,
             input_is_parallel=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
 
     def forward(
@@ -393,7 +400,8 @@ def __init__(
         self.is_gated = is_gated
         self.intermediate_size = config.intermediate_size
 
-        self.self_attn = MllamaVisionSdpaAttention(config)
+        self.self_attn = MllamaVisionSdpaAttention(config,
+                                                   quant_config=quant_config)
         self.mlp = CLIPMLP(config,
                            quant_config=quant_config,
                            prefix=f"{prefix}.mlp")
@@ -1002,6 +1010,7 @@ def __init__(
             org_num_embeddings=config.vocab_size,
             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
             quant_config=quant_config,
+            prefix=f"{prefix}.lm_head",
         )
 
     def forward(
@@ -1037,6 +1046,26 @@ def forward(
 @INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
 class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
+    # BitsAndBytes-specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    # in TP, these weights are partitioned along the column dimension (dim=-1)
+    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
 
     def __init__(self,
                  config: config_mllama.MllamaConfig,
@@ -1061,10 +1090,13 @@ def __init__(self,
             quant_config=quant_config,
             prefix="language_model",
         )
-        self.multi_modal_projector = nn.Linear(
+        self.multi_modal_projector = ColumnParallelLinear(
             config.vision_config.vision_output_dim,
             config.text_config.hidden_size,
             bias=True,
+            quant_config=quant_config,
+            gather_output=True,
+            prefix="multi_modal_projector",
         )
         self.logits_processor = LogitsProcessor(config.output_hidden_states,
                                                 config.text_config.vocab_size)
@@ -1128,7 +1160,7 @@ def _parse_and_validate_image_input(self, **kwargs: object):
                 raise ValueError("No images provided.")
             max_num_tiles = max(
                 max([len(x) for x in y[0]]) for y in pixel_values)
-            device = self.multi_modal_projector.weight.device
+            device = next(self.multi_modal_projector.parameters()).device
             bsz = len(pixel_values)
             out_num_tiles = []
             out_images = torch.zeros(
@@ -1204,7 +1236,7 @@ def get_cross_attention_states(
         cross_attention_states = self.vision_model(pixel_values,
                                                    aspect_ratio_ids,
                                                    aspect_ratio_mask)
-        cross_attention_states = self.multi_modal_projector(
+        cross_attention_states, _ = self.multi_modal_projector(
             cross_attention_states)
 
         bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape)
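Taken together, the class-level attributes added above tell the BNB loader which modules to quantize (default_bitsandbytes_target_modules), which weights are partitioned along the column dimension under tensor parallelism (column_parallel_weights_modules), and how per-shard q/k/v and gate/up checkpoint weights map onto vLLM's fused qkv_proj and gate_up_proj parameters (bitsandbytes_stacked_params_mapping). A rough usage sketch for serving a pre-quantized 4-bit Mllama checkpoint with this change applied; the checkpoint name and sampling settings are examples, not part of the PR:

from vllm import LLM, SamplingParams

# Example pre-quantized BNB checkpoint; any Mllama checkpoint quantized with
# bitsandbytes 4-bit (NF4) and saved with its quant-state tensors is expected
# to load the same way.
llm = LLM(
    model="unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
    quantization="bitsandbytes",
    load_format="bitsandbytes",
    max_model_len=4096,
    enforce_eager=True,
)

# Text-only prompt for brevity; image inputs go through vLLM's usual
# multi-modal prompt format.
outputs = llm.generate(
    "Describe the colour of the sky in one sentence.",
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)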