Skip to content

Commit

Permalink
[Bugfix] bitsandbytes models fail to run pipeline parallel (vllm-proj…
Browse files Browse the repository at this point in the history
…ect#10200)

Signed-off-by: Hoang Cong Duc <[email protected]>
  • Loading branch information
HoangCongDuc authored Nov 13, 2024
1 parent 0b8bb86 commit ac49b59
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
30 changes: 29 additions & 1 deletion tests/quantization/test_bitsandbytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch

from tests.quantization.utils import is_quant_method_supported
from tests.utils import fork_new_process_for_each_test
from tests.utils import compare_two_settings, fork_new_process_for_each_test

models_4bit_to_test = [
("facebook/opt-125m", "quantize opt model inflight"),
Expand Down Expand Up @@ -82,6 +82,34 @@ def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
vllm_tp_size=2)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason='Test requires at least 2 GPUs.')
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
reason='bitsandbytes is not supported on this GPU type.')
@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
@fork_new_process_for_each_test
def test_load_pp_4bit_bnb_model(model_name, description) -> None:
common_args = [
"--disable-log-stats",
"--disable-log-requests",
"--dtype",
"bfloat16",
"--enable-prefix-caching",
"--quantization",
"bitsandbytes",
"--load-format",
"bitsandbytes",
"--gpu-memory-utilization",
"0.7",
]
pp_args = [
*common_args,
"--pipeline-parallel-size",
"2",
]
compare_two_settings(model_name, common_args, pp_args)


def log_generated_texts(prompts, outputs, runner_name):
logged_texts = []
for i, (_, generated_text) in enumerate(outputs):
Expand Down
6 changes: 6 additions & 0 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,13 @@ def _load_weights(self, model_config: ModelConfig,

param_dict = dict(model.named_parameters())
stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
# TODO: Change this lazy import to normal import
# after the checks are updated to run on a new version
from vllm.model_executor.models.utils import is_pp_missing_parameter
for quant_param_name in quant_state_dict:
if is_pp_missing_parameter(quant_param_name, model):
continue

non_stacked_param_name = quant_param_name

shard_index = 0
Expand Down

0 comments on commit ac49b59

Please sign in to comment.