diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 39965ac9115c2..6a0de3034142a 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -20,6 +20,7 @@ def __init__( load_in_8bit: bool = False, load_in_4bit: bool = True, bnb_4bit_compute_dtype: str = "float32", + bnb_4bit_quant_storage: str = "uint8", bnb_4bit_quant_type: str = "fp4", bnb_4bit_use_double_quant: bool = False, llm_int8_enable_fp32_cpu_offload: bool = False, @@ -31,6 +32,7 @@ def __init__( self.load_in_8bit = load_in_8bit self.load_in_4bit = load_in_4bit self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype + self.bnb_4bit_quant_storage = bnb_4bit_quant_storage self.bnb_4bit_quant_type = bnb_4bit_quant_type self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload @@ -38,10 +40,15 @@ def __init__( self.llm_int8_skip_modules = llm_int8_skip_modules or [] self.llm_int8_threshold = llm_int8_threshold + if self.bnb_4bit_quant_storage not in ["uint8"]: + raise ValueError("Unsupported bnb_4bit_quant_storage: " + f"{self.bnb_4bit_quant_storage}") + def __repr__(self) -> str: return (f"BitsAndBytesConfig(load_in_8bit={self.load_in_8bit}, " f"load_in_4bit={self.load_in_4bit}, " f"bnb_4bit_compute_dtype={self.bnb_4bit_compute_dtype}, " + f"bnb_4bit_quant_storage={self.bnb_4bit_quant_storage}, " f"bnb_4bit_quant_type={self.bnb_4bit_quant_type}, " f"llm_int8_skip_modules={self.llm_int8_skip_modules})") @@ -80,6 +87,9 @@ def get_safe_value(config, keys, default_value=None): bnb_4bit_compute_dtype = get_safe_value(config, ["bnb_4bit_compute_dtype"], default_value="float32") + bnb_4bit_quant_storage = get_safe_value(config, + ["bnb_4bit_quant_storage"], + default_value="uint8") bnb_4bit_quant_type = get_safe_value(config, ["bnb_4bit_quant_type"], default_value="fp4") bnb_4bit_use_double_quant = get_safe_value( @@ -99,6 +109,7 @@ def get_safe_value(config, keys, default_value=None): load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit, bnb_4bit_compute_dtype=bnb_4bit_compute_dtype, + bnb_4bit_quant_storage=bnb_4bit_quant_storage, bnb_4bit_quant_type=bnb_4bit_quant_type, bnb_4bit_use_double_quant=bnb_4bit_use_double_quant, llm_int8_enable_fp32_cpu_offload=llm_int8_enable_fp32_cpu_offload,