Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Add BNB quantization support for Mllama #9720

Merged
merged 10 commits into from
Oct 29, 2024

Conversation

Isotr0py
Copy link
Collaborator

FILL IN THE PR DESCRIPTION HERE

FIX #9714 (link existing issues this PR will resolve)

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Adding or changing kernels

Each custom kernel needs a schema and one or more implementations to be registered with PyTorch.

  • Make sure custom ops are registered following PyTorch guidelines: Custom C++ and CUDA Operators and The Custom Operators Manual
  • Custom operations that return Tensors require meta-functions. Meta-functions should be implemented and registered in python so that dynamic dims can be handled automatically. See above documents for a description of meta-functions.
  • Use torch.libary.opcheck() to test the function registration and meta-function for any registered ops. See tests/kernels for examples.
  • When changing the C++ signature of an existing op, the schema must be updated to reflect the changes.
  • If a new custom type is needed, see the following document: Custom Class Support in PT2.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@Isotr0py Isotr0py marked this pull request as ready for review October 27, 2024 12:16
@Isotr0py Isotr0py requested a review from mgoin October 27, 2024 12:16
Copy link
Member

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, have you tested loading a model on HF?

@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 28, 2024
@Isotr0py
Copy link
Collaborator Author

Sure. I ran both unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit and Isotr0py/Llama-3.2-11B-Vision-Instruct-bnb-8bit with full/partial quantization. Both model can be loaded and generate reasonable outputs:

Llama-3.2-11B-Vision-Instruct-bnb-4bit
INFO 10-29 07:02:53 llm_engine.py:240] Initializing an LLM engine (v0.1.dev3049+g82c2515.d20241019) with config: model='unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit', speculative_config=None, tokenizer='unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.BITSANDBYTES, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=bitsandbytes, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit, num_scheduler_steps=1, chunked_prefill_enabled=False multi_step_stream_outputs=True, enable_prefix_caching=False, use_async_output_proc=False, use_cached_outputs=False, chat_template_text_format=string, mm_processor_kwargs=None)
INFO 10-29 07:02:54 enc_dec_model_runner.py:141] EncoderDecoderModelRunner requires XFormers backend; overriding backend auto-selection and forcing XFormers.
INFO 10-29 07:02:54 selector.py:119] Using XFormers backend.
INFO 10-29 07:03:04 model_runner.py:1055] Starting to load model unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit...
INFO 10-29 07:03:05 selector.py:119] Using XFormers backend.
INFO 10-29 07:03:05 loader.py:1064] Loading weights with BitsAndBytes quantization.  May take a while ...
INFO 10-29 07:03:05 weight_utils.py:243] Using model weights format ['*.safetensors']
model-00002-of-00002.safetensors: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2.18G/2.18G [00:52<00:00, 41.9MB/s]
model-00001-of-00002.safetensors: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5.00G/5.00G [01:59<00:00, 41.9MB/s]
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]███████████████████████████████████████████████████████████████████████████████| 5.00G/5.00G [01:59<00:00, 42.1MB/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:01<00:01,  1.82s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.16s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.26s/it]

Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:02<00:02,  2.37s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:03<00:00,  1.47s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:03<00:00,  1.60s/it]

INFO 10-29 07:05:11 model_runner.py:1066] Loading model weights took 6.7244 GB
INFO 10-29 07:05:11 enc_dec_model_runner.py:301] Starting profile run for multi-modal models.
INFO 10-29 07:05:52 worker.py:260] Memory profiling results: total_gpu_memory=14.74GiB initial_memory_usage=6.88GiB peak_torch_memory=7.74GiB memory_usage_post_profile=6.93Gib non_torch_memory=0.20GiB kv_cache_size=5.33GiB gpu_memory_utilization=0.90
INFO 10-29 07:05:52 gpu_executor.py:122] # GPU blocks: 2183, # CPU blocks: 1638
INFO 10-29 07:05:52 gpu_executor.py:126] Maximum concurrency for 4096 tokens per request: 8.53x
WARNING 10-29 07:05:59 preprocess.py:89] Falling back on <BOS> for decoder start token id because decoder start token id is not available.
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:58<00:00, 14.75s/it, est. speed input: 0.68 toks/s, output: 4.34 toks/s]
 The image depicts a tall tower with a white dome and a spire, surrounded by pink cherry blossom trees. The tower is in the center of the image, and the blossoms are in the foreground. The sky is blue and clear, suggesting that it is spring or early summer. The overall atmosphere of the image is
 The image is a photograph of a white tower framed by cherry blossoms. The tower is in the center of the image and is framed by branches of cherry blossoms. The tower is tall and white with a pointed top. The cherry blossoms are pink and are in full bloom. The background is a clear blue sky
 The image depicts a tall tower, likely a skyscraper or a tower of some kind, with a white exterior and a pointed top, surrounded by a tree with pink flowers. The tree is in the foreground of the image, and the tower is in the background. The sky is blue and clear, suggesting that the photo
 The image shows a tall white tower, framed by branches of a cherry blossom tree. The tower is in the center of the image, and it is surrounded by pink flowers. The sky is blue and clear. The overall atmosphere suggests a spring or summer day, with the cherry blossoms in full bloom. The image is
Llama-3.2-11B-Vision-Instruct-bnb-8bit
WARNING 10-29 13:47:25 config.py:1707] Casting torch.bfloat16 to torch.float16.
WARNING 10-29 13:47:29 config.py:364] bitsandbytes quantization is not fully optimized yet. The speed can be slower than non-quantized models.
WARNING 10-29 13:47:29 config.py:438] To see benefits of async output processing, enable CUDA graph. Since, enforce-eager is enabled, async output processor cannot be used
INFO 10-29 13:47:29 llm_engine.py:240] Initializing an LLM engine (vdev) with config: model='/root/autodl-tmp/Llama-3.2-11B-Vision-Instruct-bnb-8bit', speculative_config=None, tokenizer='/root/autodl-tmp/Llama-3.2-11B-Vision-Instruct-bnb-8bit', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.BITSANDBYTES, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=bitsandbytes, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=/root/autodl-tmp/Llama-3.2-11B-Vision-Instruct-bnb-8bit, num_scheduler_steps=1, chunked_prefill_enabled=False multi_step_stream_outputs=True, enable_prefix_caching=False, use_async_output_proc=False, use_cached_outputs=False, chat_template_text_format=string, mm_processor_kwargs=None)
INFO 10-29 13:47:30 enc_dec_model_runner.py:141] EncoderDecoderModelRunner requires XFormers backend; overriding backend auto-selection and forcing XFormers.
INFO 10-29 13:47:30 selector.py:119] Using XFormers backend.
/root/miniconda3/lib/python3.12/site-packages/xformers/ops/fmha/flash.py:211: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
  @torch.library.impl_abstract("xformers_flash::flash_fwd")
/root/miniconda3/lib/python3.12/site-packages/xformers/ops/fmha/flash.py:344: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
  @torch.library.impl_abstract("xformers_flash::flash_bwd")
INFO 10-29 13:47:31 model_runner.py:1055] Starting to load model /root/autodl-tmp/Llama-3.2-11B-Vision-Instruct-bnb-8bit...
INFO 10-29 13:47:31 selector.py:119] Using XFormers backend.
INFO 10-29 13:47:31 loader.py:1064] Loading weights with BitsAndBytes quantization.  May take a while ...
Loading safetensors checkpoint shards:   0% Completed | 0/3 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  33% Completed | 1/3 [00:00<00:01,  1.70it/s]
Loading safetensors checkpoint shards:  67% Completed | 2/3 [00:02<00:01,  1.19s/it]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:03<00:00,  1.38s/it]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:03<00:00,  1.27s/it]

Loading safetensors checkpoint shards:   0% Completed | 0/3 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  33% Completed | 1/3 [00:00<00:01,  1.72it/s]
Loading safetensors checkpoint shards:  67% Completed | 2/3 [00:02<00:01,  1.22s/it]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:03<00:00,  1.40s/it]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:03<00:00,  1.29s/it]

INFO 10-29 13:47:39 model_runner.py:1066] Loading model weights took 11.0580 GB
INFO 10-29 13:47:39 enc_dec_model_runner.py:301] Starting profile run for multi-modal models.
/root/miniconda3/lib/python3.12/site-packages/bitsandbytes/autograd/_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
INFO 10-29 13:47:51 worker.py:260] Memory profiling results: total_gpu_memory=31.73GiB initial_memory_usage=11.50GiB peak_torch_memory=13.64GiB memory_usage_post_profile=13.43Gib non_torch_memory=1.07GiB kv_cache_size=13.85GiB gpu_memory_utilization=0.90
INFO 10-29 13:47:51 gpu_executor.py:122] # GPU blocks: 5671, # CPU blocks: 1638
INFO 10-29 13:47:51 gpu_executor.py:126] Maximum concurrency for 4096 tokens per request: 22.15x
WARNING 10-29 13:47:54 preprocess.py:89] Falling back on <BOS> for decoder start token id because decoder start token id is not available.
Processed prompts: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [01:18<00:00, 19.75s/it, est. speed input: 0.51 toks/s, output: 3.24 toks/s]
 The image shows a cherry blossom tree in front of a tall tower. The tree is in full bloom, with pink flowers covering its branches. The tower is white and has a distinctive shape, with a large dome at the top and a series of smaller domes below it. The sky is blue and clear, suggesting that
 The image is a photograph of a white tower framed by cherry blossoms. The tower is in the center of the image and is tall and slender. It has a round top and a narrow, tall, white, metal lattice structure. The tower is framed by branches of a blooming cherry tree. The branches are in
 The image is a photograph of a white tower surrounded by pink cherry blossoms. The tower is in the center of the image and is tall and slender, with a rounded top. It has a series of windows and a spire at the top. The tower is surrounded by branches of a cherry blossom tree that are covered
 The image depicts a tall white tower, possibly a skyscraper or a tower, surrounded by pink cherry blossoms. The tower is in the center of the image and is surrounded by branches of pink cherry blossoms. The sky is blue and clear, suggesting that it is a sunny day. The overall atmosphere of the image

@mgoin mgoin merged commit 09500f7 into vllm-project:main Oct 29, 2024
63 checks passed
@Isotr0py Isotr0py deleted the mllama-bnb branch October 29, 2024 12:22
rasmith pushed a commit to rasmith/vllm that referenced this pull request Oct 30, 2024
NickLucche pushed a commit to NickLucche/vllm that referenced this pull request Oct 31, 2024
NickLucche pushed a commit to NickLucche/vllm that referenced this pull request Oct 31, 2024
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Nov 4, 2024
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Nov 4, 2024
@luizhsalazar
Copy link

luizhsalazar commented Nov 5, 2024

Is this feature available in which version? or is still in dev?

I'm facing the error AttributeError: Model MllamaForConditionalGeneration does not support BitsAndBytes quantization yet when try to run quantized model with the command:

docker run --runtime nvidia --gpus all
-v ~/.cache/huggingface:/root/.cache/huggingface
--env "HUGGING_FACE_HUB_TOKEN="
-p 8000:8000
--ipc=host
vllm/vllm-openai:v0.6.3.post1
--model unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit
--dtype half
--quantization bitsandbytes
--load_format bitsandbytes
--max_model_len 50000
--gpu_memory_utilization 0.99
--trust-remote-code
--enforce-eager

JC1DA pushed a commit to JC1DA/vllm that referenced this pull request Nov 11, 2024
sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
mfournioux pushed a commit to mfournioux/vllm that referenced this pull request Nov 20, 2024
tlrmchlsmth pushed a commit to neuralmagic/vllm that referenced this pull request Nov 23, 2024
sleepwalker2017 pushed a commit to sleepwalker2017/vllm that referenced this pull request Dec 13, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Feature]: AttributeError: Model MllamaForConditionalGeneration does not support BitsAndBytes quantization yet
3 participants