diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py
index faa8d92e83de3..7a039a78f09b8 100644
--- a/vllm/model_executor/layers/quantization/bitsandbytes.py
+++ b/vllm/model_executor/layers/quantization/bitsandbytes.py
@@ -3,6 +3,7 @@
 import torch

 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               UnquantizedLinearMethod,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
@@ -23,7 +24,7 @@ def __init__(
         bnb_4bit_use_double_quant: bool = False,
         llm_int8_enable_fp32_cpu_offload: bool = False,
         llm_int8_has_fp16_weight: bool = False,
-        llm_int8_skip_modules: Optional[Any] = None,
+        llm_int8_skip_modules: Optional[List[str]] = None,
         llm_int8_threshold: float = 0.0,
     ) -> None:

@@ -34,11 +35,15 @@ def __init__(
         self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
         self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
         self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
-        self.llm_int8_skip_modules = llm_int8_skip_modules
+        self.llm_int8_skip_modules = llm_int8_skip_modules or []
         self.llm_int8_threshold = llm_int8_threshold

     def __repr__(self) -> str:
-        return "BitsAndBytesConfig"
+        return (f"BitsAndBytesConfig(load_in_8bit={self.load_in_8bit}, "
+                f"load_in_4bit={self.load_in_4bit}, "
+                f"bnb_4bit_compute_dtype={self.bnb_4bit_compute_dtype}, "
+                f"bnb_4bit_quant_type={self.bnb_4bit_quant_type}, "
+                f"llm_int8_skip_modules={self.llm_int8_skip_modules})")

     @classmethod
     def get_name(self) -> str:
@@ -102,8 +107,10 @@ def get_safe_value(config, keys, default_value=None):
             llm_int8_threshold=llm_int8_threshold)

     def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["BitsAndBytesLinearMethod"]:
+                         prefix: str) -> Optional["LinearMethodBase"]:
         if isinstance(layer, LinearBase):
+            if is_layer_skipped_bnb(prefix, self.llm_int8_skip_modules):
+                return UnquantizedLinearMethod()
             return BitsAndBytesLinearMethod(self)
         return None

@@ -111,6 +118,10 @@ def get_scaled_act_names(self) -> List[str]:
         return []


+def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]):
+    return any(module_name in prefix for module_name in llm_int8_skip_modules)
+
+
 class BitsAndBytesLinearMethod(LinearMethodBase):
     """Linear method for BitsAndBytes.
@@ -211,6 +222,11 @@ def _apply_8bit_weight(
         from bitsandbytes import MatmulLtState, matmul

         original_type = x.dtype
+        original_shape = x.shape
+        reshape_after_matmul = False
+        if x.ndim > 2:
+            x = x.reshape(-1, x.size(-1))
+            reshape_after_matmul = True
         bf_x = x.to(torch.bfloat16)

         qweight = layer.qweight
@@ -265,6 +281,9 @@ def _apply_8bit_weight(

         out = out.to(original_type)

+        if reshape_after_matmul:
+            out = out.view(*original_shape[:-1], out.size(-1))
+
         if bias is not None:
             out += bias

@@ -282,6 +301,11 @@ def _apply_4bit_weight(
         from bitsandbytes import matmul_4bit

         original_type = x.dtype
+        original_shape = x.shape
+        reshape_after_matmul = False
+        if x.ndim > 2:
+            x = x.reshape(-1, x.size(-1))
+            reshape_after_matmul = True
         bf_x = x.to(torch.bfloat16)

         qweight = layer.qweight
@@ -310,6 +334,9 @@ def _apply_4bit_weight(

         out = out.to(original_type)

+        if reshape_after_matmul:
+            out = out.view(*original_shape[:-1], out.size(-1))
+
         if bias is not None:
             out += bias

diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 813f58339da37..3cfee13b9fa6e 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -899,6 +899,19 @@ def _get_quantized_weights_iterator(
         return self._unquantized_generator(hf_weights_files, use_safetensors,
                                            quant_state_dict), quant_state_dict

+    def _is_8bit_weight_name(self, weight_name: str):
+        quantized_suffix = {".scb", ".weight_format"}
+        return any(weight_name.lower().endswith(suffix)
+                   for suffix in quantized_suffix)
+
+    def _is_4bit_weight_name(self, weight_name: str):
+        quantized_suffix = {
+            "absmax", "quant_map", "nested_absmax", "nested_quant_map",
+            "bitsandbytes"
+        }
+        suffix = weight_name.split(".")[-1]
+        return any(q_suffix in suffix for q_suffix in quantized_suffix)
+
     def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
                                   quant_state_dict) -> Generator:
         for weight_name, weight_tensor in self._hf_weight_iter(
@@ -912,7 +925,7 @@ def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
         for weight_name, weight_tensor in self._hf_weight_iter(
                 hf_weights_files, use_safetensors):

-            if not weight_name.endswith((".weight", ".bias")):
+            if self._is_8bit_weight_name(weight_name):
                 continue

             qweight_name = weight_name.replace(".weight", ".qweight")
@@ -932,7 +945,7 @@ def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
                                               use_safetensors)
         temp_state_dict = {}
         for weight_name, weight_tensor in weight_iterator:
-            if weight_name.endswith((".weight", ".bias")):
+            if not self._is_4bit_weight_name(weight_name):
                 continue
             # bitsandbytes library requires
             # weight.quant_state.bitsandbytes__* in CPU
@@ -956,7 +969,7 @@ def _parse_quant_state(param_name: str,
         for weight_name, weight_tensor in self._hf_weight_iter(
                 hf_weights_files, use_safetensors):

-            if not weight_name.endswith((".weight", ".bias")):
+            if self._is_4bit_weight_name(weight_name):
                 continue

             if (f"{weight_name}.quant_state.bitsandbytes__nf4" \
diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 44ef49729c969..5cf5272cae878 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -325,7 +325,10 @@ def forward(self, hidden_state: torch.Tensor,
 # TODO: support other attention backends for attention in vision model
 class MllamaVisionSdpaAttention(nn.Module):

-    def __init__(self, config: config_mllama.MllamaVisionConfig):
+    def __init__(self,
+                 config: config_mllama.MllamaVisionConfig,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
         super().__init__()

         model_parallel_size = get_tensor_model_parallel_world_size()
@@ -341,12 +344,16 @@ def __init__(self, config: config_mllama.MllamaVisionConfig):
             self.head_dim,
             self.num_heads,
             bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.num_heads * self.head_dim,
             self.embed_dim,
             bias=False,
             input_is_parallel=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )

     def forward(
@@ -393,7 +400,8 @@ def __init__(
         self.is_gated = is_gated
         self.intermediate_size = config.intermediate_size

-        self.self_attn = MllamaVisionSdpaAttention(config)
+        self.self_attn = MllamaVisionSdpaAttention(
+            config, quant_config=quant_config, prefix=f"{prefix}.self_attn")
         self.mlp = CLIPMLP(config,
                            quant_config=quant_config,
                            prefix=f"{prefix}.mlp")
@@ -1002,6 +1010,7 @@ def __init__(
             org_num_embeddings=config.vocab_size,
             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
             quant_config=quant_config,
+            prefix=f"{prefix}.lm_head",
         )

     def forward(
@@ -1037,6 +1046,26 @@ def forward(
 @INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
 class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
+    # BitsAndBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    # in TP, these weights are partitioned along the column dimension (dim=-1)
+    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }

     def __init__(self,
                  config: config_mllama.MllamaConfig,
@@ -1061,10 +1090,13 @@ def __init__(self,
             quant_config=quant_config,
             prefix="language_model",
         )
-        self.multi_modal_projector = nn.Linear(
+        self.multi_modal_projector = ColumnParallelLinear(
             config.vision_config.vision_output_dim,
             config.text_config.hidden_size,
             bias=True,
+            quant_config=quant_config,
+            gather_output=True,
+            prefix="multi_modal_projector",
         )
         self.logits_processor = LogitsProcessor(config.output_hidden_states,
                                                 config.text_config.vocab_size)
@@ -1128,7 +1160,7 @@ def _parse_and_validate_image_input(self, **kwargs: object):
             raise ValueError("No images provided.")
         max_num_tiles = max(
             max([len(x) for x in y[0]]) for y in pixel_values)
-        device = self.multi_modal_projector.weight.device
+        device = next(self.multi_modal_projector.parameters()).device
         bsz = len(pixel_values)
         out_num_tiles = []
         out_images = torch.zeros(
@@ -1204,7 +1236,7 @@ def get_cross_attention_states(
         cross_attention_states = self.vision_model(pixel_values,
                                                    aspect_ratio_ids,
                                                    aspect_ratio_mask)
-        cross_attention_states = self.multi_modal_projector(
+        cross_attention_states, _ = self.multi_modal_projector(
             cross_attention_states)

         bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape)
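For reference, a minimal standalone sketch (not part of the patch) of the two mechanisms the diff relies on: the prefix-based llm_int8_skip_modules check that makes get_quant_method fall back to an unquantized linear method, and the flatten-then-restore handling of activations with more than two dimensions around a 2-D matmul. It uses plain PyTorch only; apply_with_reshape and the example prefixes below are hypothetical stand-ins, not vLLM or bitsandbytes APIs.

from typing import List

import torch


def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]) -> bool:
    # Mirrors the helper added in bitsandbytes.py: a layer stays unquantized
    # when any configured module name occurs in its weight-name prefix.
    return any(module_name in prefix for module_name in llm_int8_skip_modules)


def apply_with_reshape(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Stand-in for the _apply_*bit_weight pattern: the underlying matmul only
    # accepts 2-D activations, so flatten the leading dims first and restore
    # them afterwards.
    original_shape = x.shape
    reshape_after_matmul = False
    if x.ndim > 2:
        x = x.reshape(-1, x.size(-1))
        reshape_after_matmul = True
    out = x @ weight.t()  # plain torch matmul in place of the bnb kernel
    if reshape_after_matmul:
        out = out.view(*original_shape[:-1], out.size(-1))
    return out


if __name__ == "__main__":
    # Hypothetical skip list; in vLLM the names come from the HF
    # llm_int8_skip_modules config entry.
    print(is_layer_skipped_bnb("model.layers.0.mlp.router", ["router"]))  # True

    # A 3-D activation (batch, tokens, hidden) keeps its leading dims.
    x = torch.randn(2, 7, 16)
    w = torch.randn(32, 16)
    print(apply_with_reshape(x, w).shape)  # torch.Size([2, 7, 32])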