[bitsandbytes]: support read bnb pre-quantized model such as lllyasviel/omost-llama-3-8b-4bits
thesues committed Jun 26, 2024
1 parent f1e72cc commit 3377572
Showing 6 changed files with 94 additions and 38 deletions.
13 changes: 10 additions & 3 deletions tests/quantization/test_bitsandbytes.py
@@ -8,15 +8,20 @@
from tests.quantization.utils import is_quant_method_supported
from vllm import SamplingParams

models_to_test = [
('huggyllama/llama-7b', 'quantize model inflight'),
('lllyasviel/omost-llama-3-8b-4bits', 'read pre-quantized model'),
]


@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
reason='bitsandbytes is not supported on this GPU type.')
def test_load_bnb_model(vllm_runner) -> None:
with vllm_runner('huggyllama/llama-7b',
@pytest.mark.parametrize("model_name, description", models_to_test)
def test_load_bnb_model(vllm_runner, model_name, description) -> None:
with vllm_runner(model_name,
quantization='bitsandbytes',
load_format='bitsandbytes',
enforce_eager=True) as llm:

model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501

# check the weights in MLP & SelfAttention are quantized to torch.uint8
@@ -72,5 +77,7 @@ def test_load_bnb_model(vllm_runner) -> None:
# compare the first line of the output
actual_output = outputs[index][1][0].split('\n', 1)[0]
expected_output = expected_outputs[index].split('\n', 1)[0]
#truncate actual_output
actual_output = actual_output[:len(expected_output)]
assert actual_output == expected_output, (
f'Expected: {expected_output}, but got: {actual_output}')
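
For context, a minimal offline-inference sketch of the path the new pre-quantized test case exercises. It assumes the `vllm.LLM` entry point forwards `quantization` and `load_format` to the engine the same way `vllm_runner` does in the test above; the prompt and sampling settings are placeholders.

```python
from vllm import LLM, SamplingParams

# Load a checkpoint that already ships bitsandbytes nf4 weights, so no
# in-flight quantization is needed.
llm = LLM(
    model="lllyasviel/omost-llama-3-8b-4bits",
    quantization="bitsandbytes",
    load_format="bitsandbytes",
    enforce_eager=True,
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```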
1 change: 1 addition & 0 deletions vllm/config.py
@@ -524,6 +524,7 @@ class LoadConfig:
mainly for profiling.
"tensorizer" will use CoreWeave's tensorizer library for
fast weight loading.
"bitsandbytes" will load nf4 type weights.
"""

load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
4 changes: 2 additions & 2 deletions vllm/engine/arg_utils.py
@@ -623,8 +623,8 @@ def create_engine_config(self, ) -> EngineConfig:
# bitsandbytes quantization needs a specific model loader
# so we make sure the quant method and the load format are consistent
if (self.quantization == "bitsandbytes" or
self.qlora_adapter_name_or_path is not None) and \
self.load_format != "bitsandbytes":
self.qlora_adapter_name_or_path is not None) and \
self.load_format != "bitsandbytes":
raise ValueError(
"BitsAndBytes quantization and QLoRA adapter only support "
f"'bitsandbytes' load format, but got {self.load_format}")
25 changes: 4 additions & 21 deletions vllm/model_executor/layers/quantization/bitsandbytes.py
@@ -15,19 +15,11 @@ class BitsAndBytesConfig(QuantizationConfig):
Reference: https://arxiv.org/abs/2305.14314
"""

def __init__(
self,
adapter_name_or_path: str,
target_modules: List[str],
) -> None:

self.adapter_name_or_path = adapter_name_or_path
self.target_modules = target_modules
def __init__(self, ) -> None:
pass

def __repr__(self) -> str:
return (
f"BitsAndBytesConfig(adapter_name_or_path={self.adapter_name_or_path}"
)
return ("BitsAndBytesConfig")

@classmethod
def get_name(self) -> str:
@@ -49,16 +41,7 @@ def get_config_filenames() -> List[str]:

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig":
adapter_name = cls.get_from_keys(config, ["adapter_name_or_path"])
default_target_modules = [
"gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
"o_proj"
]
if adapter_name == "":
target_modules = default_target_modules
else:
target_modules = cls.get_from_keys(config, ["target_modules"])
return cls(adapter_name, target_modules)
return cls()

def get_quant_method(
self,
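
For reference, a sketch of the kind of `quantization_config` block a pre-quantized bnb checkpoint carries in its HF config.json and that `from_config` now receives. Only `quant_method` is relied on by this commit; the other keys are illustrative of Hugging Face's serialization, not taken from the change itself.

```python
from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig

# Illustrative quantization_config for an nf4 checkpoint; from_config() above
# no longer reads adapter fields and simply returns a bare BitsAndBytesConfig.
hf_quantization_config = {
    "quant_method": "bitsandbytes",   # the key the model loader checks
    "load_in_4bit": True,
    "bnb_4bit_quant_type": "nf4",     # only nf4 is supported by this change
    "bnb_4bit_use_double_quant": True,
    "bnb_4bit_compute_dtype": "bfloat16",
}

quant_config = BitsAndBytesConfig.from_config(hf_quantization_config)
```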
88 changes: 76 additions & 12 deletions vllm/model_executor/model_loader/loader.py
@@ -681,8 +681,14 @@ def _prepare_weights(self, model_name_or_path: str,

return hf_weights_files, matched_pattern == "*.safetensors"

def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
if use_safetensors:
return safetensors_weights_iterator(hf_weights_files)
else:
return pt_weights_iterator(hf_weights_files)

def _get_quantized_weights_iterator(
self, model_name_or_path: str, revision: Optional[str]
self, model_name_or_path: str, revision: Optional[str], pre_quant: bool
) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
Any]]:
"""Get an iterator to the model weights with bitsandbytes quantization,
@@ -691,6 +697,7 @@ def _get_quantized_weights_iterator(
# only load the bitsandbytes module when needed
try:
import bitsandbytes
from bitsandbytes.functional import QuantState
if bitsandbytes.__version__ < "0.42.0":
raise ImportError("bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.42.0.")
@@ -704,17 +711,63 @@ def _get_quantized_weights_iterator(
model_name_or_path, revision)

quant_state_dict = {}
if use_safetensors:
weight_iterator = safetensors_weights_iterator(hf_weights_files)
else:
weight_iterator = pt_weights_iterator(hf_weights_files)

def generator():
def quantized_checkpoint() -> Generator:
# First iterate over all quant state weights
weight_iterator = self._hf_weight_iter(hf_weights_files,
use_safetensors)
temp_state_dict = {}
for weight_name, weight_tensor in weight_iterator:
if weight_name.endswith(".weight"):
continue
# TODO: only nf4 quantization is supported for now
if weight_name.endswith(".quant_state.bitsandbytes__fp4"):
raise NotImplementedError(
"Only bitsandbytes_nf4 quantization"\
"is supported right now."
)
temp_state_dict[weight_name] = weight_tensor

# Closure to parse quant_state for each prequant weight
def _parse_quant_state(param_name: str,
temp_state_dict: Dict) -> QuantState:
quant_state = {}
for k in temp_state_dict:
if param_name + "." in k:
quant_state[k] = temp_state_dict[k]
# bitsandbytes library requires
# weight.quant_state.bitsandbytes__nf4 in CPU
quant_state[param_name +
".quant_state.bitsandbytes__nf4"] = quant_state[
param_name +
".quant_state.bitsandbytes__nf4"].cpu().data
return QuantState.from_dict(quant_state, device="cuda")

# Second iterate over all prequant and normal weights
# pre quantized weights would have a quant_state
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
# Filter out all weights whose suffix is not ".weight"
if not weight_name.endswith(".weight"):
continue
if weight_name + ".quant_state.bitsandbytes__nf4" \
in temp_state_dict:
quant_state = _parse_quant_state(weight_name,
temp_state_dict)
weight_name = weight_name.replace(".weight", ".qweight")
quant_state_dict[weight_name] = quant_state
yield weight_name.replace(".weight",
".qweight"), weight_tensor
else:
yield weight_name, weight_tensor

def generator() -> Generator:
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
if any(target_module in weight_name
for target_module in self.target_modules):
weight_name = weight_name.replace(".weight", ".qweight")
# bitsandbytes requires data in GPU
# bitsandbytes requires data in GPU
loaded_weight = weight_tensor.cuda().data
with set_default_torch_dtype(torch.float32):
processed_weight, quant_state = quantize_4bit(
@@ -728,6 +781,8 @@ def generator():

yield weight_name, processed_weight

if pre_quant:
return quantized_checkpoint(), quant_state_dict
return generator(), quant_state_dict
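
To make the two passes above concrete, a small standalone inspection sketch: it lists which tensors in a pre-quantized shard are packed weights (yielded as `.qweight`) and which are the auxiliary quant-state entries that `_parse_quant_state` folds into a `QuantState`. The shard file name is hypothetical; the key suffixes follow the convention used in the code above.

```python
from safetensors import safe_open

# Hypothetical local shard of lllyasviel/omost-llama-3-8b-4bits.
shard_path = "model-00001-of-00002.safetensors"

with safe_open(shard_path, framework="pt") as f:
    for name in f.keys():
        if name.endswith(".quant_state.bitsandbytes__nf4"):
            kind = "serialized QuantState blob (moved to CPU, parsed via QuantState.from_dict)"
        elif name.endswith(".weight"):
            kind = "packed 4-bit weight (yielded as .qweight)"
        else:
            kind = "auxiliary quant-state tensor (absmax, quant_map, ...)"
        print(f"{name}: {kind}")
```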

def _load_weights(self, model_config: ModelConfig,
@@ -745,12 +800,21 @@ def _load_weights(self, model_config: ModelConfig,
logger.info("Loading weights with BitsAndBytes quantization. "
" May take a while ...")

qweight_iterator, quant_state_dict = (
self._get_quantized_weights_iterator(model_config.model,
model_config.revision))
is_quantized_checkpoint = False
quant_config = getattr(model_config.hf_config, "quantization_config",
None)
if quant_config is not None and quant_config.get(
'quant_method') == "bitsandbytes":
is_quantized_checkpoint = True

qweight_iterator, quant_state_dict = \
self._get_quantized_weights_iterator(
model_config.model, model_config.revision, is_quantized_checkpoint)

model.load_weights(qweight_iterator)

torch.cuda.empty_cache()

param_dict = dict(model.named_parameters())
stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
for quant_param_name in quant_state_dict:
@@ -788,9 +852,9 @@ def _load_weights(self, model_config: ModelConfig,
f"pack_factor not set for parameter {param_name}.")

num_elements = [0] * len(quant_states)
for seq, quant_state in enumerate(quant_states.items()):
for seq, quant_state in quant_states.items():
num_elements[seq] = math.prod(
quant_state[1].shape) // pack_ratio
quant_state.shape) // pack_ratio

offsets = np.concatenate(([0], np.cumsum(num_elements)))
set_weight_attrs(param, {"bnb_shard_offsets": offsets})
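
Finally, a small numeric sketch of the `bnb_shard_offsets` computation in the last hunk: for a fused parameter assembled from several shards (for example q/k/v projections), each shard contributes its element count divided by the pack ratio, and the cumulative sum gives the slice boundaries inside the fused `qweight`. The shapes below are made up for illustration.

```python
import math

import numpy as np

pack_ratio = 2  # two 4-bit values are packed into each uint8 element
shard_shapes = {0: (4096, 4096), 1: (1024, 4096), 2: (1024, 4096)}  # e.g. q, k, v

num_elements = [0] * len(shard_shapes)
for seq, shape in shard_shapes.items():
    num_elements[seq] = math.prod(shape) // pack_ratio

offsets = np.concatenate(([0], np.cumsum(num_elements)))
# offsets == [0, 8388608, 10485760, 12582912]; shard i occupies
# the region offsets[i]:offsets[i + 1] of the fused qweight
print(offsets)
```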
1 change: 1 addition & 0 deletions vllm/model_executor/model_loader/weight_utils.py
@@ -117,6 +117,7 @@ def convert_bin_to_safetensor_file(
# TODO(woosuk): Move this to other place.
def get_quant_config(model_config: ModelConfig,
load_config: LoadConfig) -> QuantizationConfig:

quant_cls = get_quantization_config(model_config.quantization)
# Read the quantization config from the HF model config, if available.
hf_quant_config = getattr(model_config.hf_config, "quantization_config",
