[Model] Support Pixtral models in the HF Transformers format #9036
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can do one of these:
Hi @mgoin, thanks for your contribution! Will you continue to fix the PR?
@wuxiyiye I'm slowly working through the issues, but it is quite a lot due to poor reuse of existing Llava features. I would greatly appreciate it if others have the bandwidth to work on this.
Also, I have verified that an FP8 checkpoint loads and produces good output:

```python
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

model_name = "nm-testing/pixtral-12b-FP8-dynamic"
llm = LLM(
    model=model_name,
    max_num_seqs=1,
    enforce_eager=True,
    max_model_len=10000,
    limit_mm_per_prompt={"image": 2},
)

image1 = ImageAsset("cherry_blossom").pil_image.convert("RGB")
image2 = ImageAsset("stop_sign").pil_image.convert("RGB")
inputs = {
    "prompt": "<s>[INST]Describe the images.\n[IMG][IMG][/INST]",
    "multi_modal_data": {
        "image": [image1, image2]
    },
}
outputs = llm.generate(inputs, sampling_params=SamplingParams(temperature=0.0, max_tokens=200))
print(outputs[0].outputs[0].text)
```
Thanks for your hard work! Some initial comments.
```python
replace_tokens = [[processor.image_token] * num_width_tokens +
                  [processor.image_break_token]] * num_height_tokens
# Flatten list
replace_tokens = [
    item for sublist in replace_tokens for item in sublist
]
replace_tokens[-1] = processor.image_end_token
replace_str = "".join(replace_tokens)
replace_strings.append(replace_str)
new_prompt = new_prompt.replace(processor.image_token, "<placeholder>",
                                1)

while "<placeholder>" in new_prompt:
    replace_str = replace_strings.pop(0)
    new_prompt = new_prompt.replace("<placeholder>", replace_str, 1)
```
Depending on the prompt, this may be quite expensive. I suggest using the more optimized `vllm.multimodal.utils.repeat_and_pad_placeholder_tokens` function.
The issue with using `repeat_and_pad_placeholder_tokens` is that we need to insert `image_break_token` at the end of every row and `image_end_token` at the end of the image, along with supporting multiple differently sized images in a prompt. I think we can optimize this later with a new implementation that supports this.
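To make the required layout concrete, here is a small self-contained sketch of the per-image replacement string being discussed. The token strings are illustrative stand-ins for the processor's attributes, not the real values; the point is that each row of image tokens ends with a break token and the final position of the grid becomes the end-of-image token.

```python
# Minimal sketch of the per-image token layout discussed above.
# IMG, IMG_BREAK, and IMG_END are illustrative placeholders, not the
# processor's real token strings.
IMG, IMG_BREAK, IMG_END = "[IMG]", "[IMG_BREAK]", "[IMG_END]"

def build_image_placeholder(num_width_tokens: int, num_height_tokens: int) -> str:
    # One row = `num_width_tokens` image tokens followed by a break token.
    rows = [[IMG] * num_width_tokens + [IMG_BREAK] for _ in range(num_height_tokens)]
    tokens = [tok for row in rows for tok in row]  # flatten the rows
    tokens[-1] = IMG_END  # the very last break becomes the end-of-image token
    return "".join(tokens)

# A 3x2 token grid: two rows of three image tokens, each row ending with a
# break token, except the final position which is the end token.
print(build_image_placeholder(num_width_tokens=3, num_height_tokens=2))
# [IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG_END]
```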
I see, let's do it in another PR then. We should also TP the model in the future.
Overall looks good. Please see my comment above though.
Also, we should add the HF version to our list of supported models.
@mgoin I'm trying to run
Great work on this issue, guys! However, I was wondering why "nm-testing/pixtral-12b-FP8-dynamic" is supported by vLLM and "SeanScripts/pixtral-12b-nf4" (uses bitsandbytes) isn't. I get the same error as mentioned in FIX #9069. Thoughts?

Error details:

```
INFO 10-22 09:09:17 config.py:1700] Downcasting torch.float32 to torch.float16.
WARNING 10-22 09:09:24 config.py:361] bitsandbytes quantization is not fully optimized yet. The speed can be slower than non-quantized models.
WARNING 10-22 09:09:24 config.py:435] To see benefits of async output processing, enable CUDA graph. Since, enforce-eager is enabled, async output processor cannot be used
INFO 10-22 09:09:24 llm_engine.py:238] Initializing an LLM engine (v0.6.3.post2.dev37+g696b01af) with config: model='SeanScripts/pixtral-12b-nf4', speculative_config=None, tokenizer='SeanScripts/pixtral-12b-nf4', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=10000, download_dir=None, load_format=LoadFormat.BITSANDBYTES, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=bitsandbytes, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=SeanScripts/pixtral-12b-nf4, num_scheduler_steps=1, chunked_prefill_enabled=False multi_step_stream_outputs=True, enable_prefix_caching=False, use_async_output_proc=False, use_cached_outputs=False, mm_processor_kwargs=None)
INFO 10-22 09:09:27 model_runner.py:1055] Starting to load model SeanScripts/pixtral-12b-nf4...
/opt/conda/envs/prats/lib/python3.11/site-packages/xformers/ops/fmha/flash.py:211: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
  @torch.library.impl_abstract("xformers_flash::flash_fwd")
/opt/conda/envs/prats/lib/python3.11/site-packages/xformers/ops/fmha/flash.py:344: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
  @torch.library.impl_abstract("xformers_flash::flash_bwd")

AttributeError                            Traceback (most recent call last)
File /opt/conda/envs/prats/lib/python3.11/site-packages/vllm/utils.py:1073, in deprecate_args..wrapper..inner(*args, **kwargs)
File /opt/conda/envs/prats/lib/python3.11/site-packages/vllm/entrypoints/llm.py:193, in LLM.init(self, model, tokenizer, tokenizer_mode, skip_tokenizer_init, trust_remote_code, tensor_parallel_size, dtype, quantization, revision, tokenizer_revision, seed, gpu_memory_utilization, swap_space, cpu_offload_gb, enforce_eager, max_context_len_to_capture, max_seq_len_to_capture, disable_custom_all_reduce, disable_async_output_proc, mm_processor_kwargs, task, **kwargs)
File /opt/conda/envs/prats/lib/python3.11/site-packages/vllm/engine/llm_engine.py:574, in LLMEngine.from_engine_args(cls, engine_args, usage_context, stat_loggers)
File /opt/conda/envs/prats/lib/python3.11/site-packages/vllm/engine/llm_engine.py:335, in LLMEngine.init(self, model_config, cache_config, parallel_config, scheduler_config, device_config, load_config, lora_config, speculative_config, decoding_config, observability_config, prompt_adapter_config, executor_class, log_stats, usage_context, stat_loggers, input_registry, use_cached_outputs)
File /opt/conda/envs/prats/lib/python3.11/site-packages/vllm/executor/executor_base.py:47, in ExecutorBase.init(self, model_config, cache_config, parallel_config, scheduler_config, device_config, load_config, lora_config, speculative_config, prompt_adapter_config, observability_config)
File /opt/conda/envs/prats/lib/python3.11/site-packages/vllm/executor/gpu_executor.py:40, in GPUExecutor._init_executor(self)
File /opt/conda/envs/prats/lib/python3.11/site-packages/vllm/worker/worker.py:180, in Worker.load_model(self)
File /opt/conda/envs/prats/lib/python3.11/site-packages/vllm/worker/model_runner.py:1057, in GPUModelRunnerBase.load_model(self)
File /opt/conda/envs/prats/lib/python3.11/site-packages/vllm/model_executor/model_loader/init.py:19, in get_model(model_config, load_config, device_config, parallel_config, scheduler_config, lora_config, cache_config)
File /opt/conda/envs/prats/lib/python3.11/site-packages/vllm/model_executor/model_loader/loader.py:1148, in BitsAndBytesModelLoader.load_model(self, model_config, device_config, lora_config, parallel_config, scheduler_config, cache_config)
File /opt/conda/envs/prats/lib/python3.11/site-packages/vllm/model_executor/model_loader/loader.py:1033, in BitsAndBytesModelLoader._load_weights(self, model_config, model)
AttributeError: Model LlavaForConditionalGeneration does not support BitsAndBytes quantization yet.
```
@rebel-jonghewk Ah, thanks for reporting this issue. I was going to work on making a non-xformers backend for Pixtral, but in the meantime I can at least make the import lazy to solve your issue. @pratyush0599 I'll need to look into that model checkpoint, will do. For now you should be able to use the in-flight bnb quant with the `--quantization bitsandbytes` flag.
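For reference, a minimal sketch of what the in-flight bitsandbytes quantization invocation can look like through the Python API. The model name is taken from this thread; the `load_format="bitsandbytes"` argument is an assumption based on how vLLM's bitsandbytes loader is usually invoked alongside `quantization="bitsandbytes"`, and may not be required in every version:

```python
from vllm import LLM, SamplingParams

# Sketch: load the HF-format Pixtral checkpoint and quantize it in-flight
# with bitsandbytes. These arguments mirror the CLI flags
# `--quantization bitsandbytes --load-format bitsandbytes`.
llm = LLM(
    model="mistral-community/pixtral-12b",
    quantization="bitsandbytes",
    load_format="bitsandbytes",
    enforce_eager=True,
    max_model_len=10000,
)

outputs = llm.generate(
    "<s>[INST]Say hello.[/INST]",
    sampling_params=SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)
```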
@mgoin Hey, thanks for the prompt reply. I tried using vllm serve with the in-flight quantization for the original Pixtral model ("mistralai/Pixtral-12B-2409") and got the same error. I tried both models, as one uses LlavaForConditionalGeneration and the other uses PixtralForConditionalGeneration, but I am receiving the same error as above. This was my code:
FIX #8566
FIX #8685
FIX #9069
Introduces `PixtralHF`, a model implementation of the HF Transformers format of Pixtral, based off https://github.com/huggingface/transformers/blob/main/src/transformers/models/pixtral/modeling_pixtral.py

Tested with:
- mistral-community/pixtral-12b
- nm-testing/pixtral-12b-FP8-dynamic
This model implementation follows the Llava family, meaning image embeddings are placed in place of the `[IMG]` token placeholders. The model uses `PixtralVisionModel` for its vision encoder and `MistralForCausalLM` for its language decoder.

Example output from `python examples/offline_inference_vision_language.py --model pixtral_hf`:

Offline multi-image example
Script used for simple testing of multi-image:
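The exact script was collapsed in the original page; a minimal sketch of such a multi-image test, mirroring the FP8 snippet earlier in this conversation but pointing at the community HF checkpoint, might look like this:

```python
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

# Sketch of a simple multi-image test against the HF-format checkpoint.
llm = LLM(
    model="mistral-community/pixtral-12b",
    max_num_seqs=1,
    enforce_eager=True,
    max_model_len=10000,
    limit_mm_per_prompt={"image": 2},
)

image1 = ImageAsset("cherry_blossom").pil_image.convert("RGB")
image2 = ImageAsset("stop_sign").pil_image.convert("RGB")
inputs = {
    "prompt": "<s>[INST]Describe the images.\n[IMG][IMG][/INST]",
    "multi_modal_data": {"image": [image1, image2]},
}
outputs = llm.generate(inputs, sampling_params=SamplingParams(temperature=0.0, max_tokens=200))
print(outputs[0].outputs[0].text)
```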
Output:
Offline chat example
Script used for testing of chat templating:
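Likewise, the chat-templating script was collapsed; a sketch of what such a test might look like using vLLM's `LLM.chat` interface. The image URL is a stand-in, and the exact message structure is an assumption based on the OpenAI-style chat format vLLM accepts for multimodal models:

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="mistral-community/pixtral-12b",
    enforce_eager=True,
    max_model_len=10000,
)

# OpenAI-style chat messages with an image, rendered through the model's
# chat template by vLLM.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            {
                "type": "image_url",
                "image_url": {"url": "https://example.com/some_image.jpg"},
            },
        ],
    }
]

outputs = llm.chat(messages, sampling_params=SamplingParams(temperature=0.0, max_tokens=200))
print(outputs[0].outputs[0].text)
```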
Output: