
[Bugfix] Quick fix to make Pixtral-HF load correctly again after 39e227c7ae. #11024

Merged: 3 commits into vllm-project:main on Dec 12, 2024

Conversation

@sjuxax (Contributor) commented on Dec 9, 2024

After 39e227c, I was unable to start my Pixtral-HF model; it failed with this traceback:

```
INFO 12-09 08:47:15 model_runner.py:1094] Loading model weights took 8.2663 GB
ERROR 12-09 08:47:16 engine.py:366] vllm.multimodal.inputs.MultiModalKwargs() got multiple values for keyword argument 'is_pixtral'
ERROR 12-09 08:47:16 engine.py:366] Traceback (most recent call last):
ERROR 12-09 08:47:16 engine.py:366]   File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 357, in run_mp_engine
ERROR 12-09 08:47:16 engine.py:366]     engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
ERROR 12-09 08:47:16 engine.py:366]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 12-09 08:47:16 engine.py:366]   File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 119, in from_engine_args
ERROR 12-09 08:47:16 engine.py:366]     return cls(ipc_path=ipc_path,
ERROR 12-09 08:47:16 engine.py:366]            ^^^^^^^^^^^^^^^^^^^^^^
ERROR 12-09 08:47:16 engine.py:366]   File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 71, in __init__
ERROR 12-09 08:47:16 engine.py:366]     self.engine = LLMEngine(*args, **kwargs)
ERROR 12-09 08:47:16 engine.py:366]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 12-09 08:47:16 engine.py:366]   File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/vllm/engine/llm_engine.py", line 291, in __init__
ERROR 12-09 08:47:16 engine.py:366]     self._initialize_kv_caches()
ERROR 12-09 08:47:16 engine.py:366]   File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/vllm/engine/llm_engine.py", line 430, in _initialize_kv_caches
ERROR 12-09 08:47:16 engine.py:366]     self.model_executor.determine_num_available_blocks())
ERROR 12-09 08:47:16 engine.py:366]     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 12-09 08:47:16 engine.py:366]   File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/vllm/executor/gpu_executor.py", line 68, in determine_num_available_blocks
ERROR 12-09 08:47:16 engine.py:366]     return self.driver_worker.determine_num_available_blocks()
ERROR 12-09 08:47:16 engine.py:366]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 12-09 08:47:16 engine.py:366]   File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 12-09 08:47:16 engine.py:366]     return func(*args, **kwargs)
ERROR 12-09 08:47:16 engine.py:366]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 12-09 08:47:16 engine.py:366]   File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/vllm/worker/worker.py", line 199, in determine_num_available_blocks
ERROR 12-09 08:47:16 engine.py:366]     self.model_runner.profile_run()
ERROR 12-09 08:47:16 engine.py:366]   File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 12-09 08:47:16 engine.py:366]     return func(*args, **kwargs)
ERROR 12-09 08:47:16 engine.py:366]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 12-09 08:47:16 engine.py:366]   File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/vllm/worker/model_runner.py", line 1285, in profile_run
ERROR 12-09 08:47:16 engine.py:366]     .dummy_data_for_profiling(self.model_config,
ERROR 12-09 08:47:16 engine.py:366]      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 12-09 08:47:16 engine.py:366]   File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/vllm/inputs/registry.py", line 249, in dummy_data_for_profiling
ERROR 12-09 08:47:16 engine.py:366]     dummy_data = processor.get_dummy_data(seq_len, mm_counts,
ERROR 12-09 08:47:16 engine.py:366]                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 12-09 08:47:16 engine.py:366]   File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/vllm/multimodal/processing.py", line 795, in get_dummy_data
ERROR 12-09 08:47:16 engine.py:366]     multi_modal_data=self._get_dummy_mm_kwargs(mm_counts),
ERROR 12-09 08:47:16 engine.py:366]                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 12-09 08:47:16 engine.py:366]   File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/vllm/model_executor/models/llava.py", line 223, in _get_dummy_mm_kwargs
ERROR 12-09 08:47:16 engine.py:366]     return MultiModalKwargs(
ERROR 12-09 08:47:16 engine.py:366]            ^^^^^^^^^^^^^^^^^
ERROR 12-09 08:47:16 engine.py:366] TypeError: vllm.multimodal.inputs.MultiModalKwargs() got multiple values for keyword argument 'is_pixtral'
Process SpawnProcess-1:
Traceback (most recent call last):
  File "/usr/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.12/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 368, in run_mp_engine
    raise e
  File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 357, in run_mp_engine
    engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 119, in from_engine_args
    return cls(ipc_path=ipc_path,
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 71, in __init__
    self.engine = LLMEngine(*args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/vllm/engine/llm_engine.py", line 291, in __init__
    self._initialize_kv_caches()
  File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/vllm/engine/llm_engine.py", line 430, in _initialize_kv_caches
    self.model_executor.determine_num_available_blocks())
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/vllm/executor/gpu_executor.py", line 68, in determine_num_available_blocks
    return self.driver_worker.determine_num_available_blocks()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/vllm/worker/worker.py", line 199, in determine_num_available_blocks
    self.model_runner.profile_run()
  File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/vllm/worker/model_runner.py", line 1285, in profile_run
    .dummy_data_for_profiling(self.model_config,
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/vllm/inputs/registry.py", line 249, in dummy_data_for_profiling
    dummy_data = processor.get_dummy_data(seq_len, mm_counts,
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/vllm/multimodal/processing.py", line 795, in get_dummy_data
    multi_modal_data=self._get_dummy_mm_kwargs(mm_counts),
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/vllm/model_executor/models/llava.py", line 223, in _get_dummy_mm_kwargs
    return MultiModalKwargs(
           ^^^^^^^^^^^^^^^^^
TypeError: vllm.multimodal.inputs.MultiModalKwargs() got multiple values for keyword argument 'is_pixtral'
[rank0]:[W1209 08:47:17.951197986 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present,  but this warning has only been added since PyTorch 2.4 (function operator())
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/vllm/entrypoints/openai/api_server.py", line 683, in <module>
    uvloop.run(run_server(args))
  File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/uvloop/__init__.py", line 109, in run
    return __asyncio.run(
           ^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/asyncio/runners.py", line 194, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete
  File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/uvloop/__init__.py", line 61, in wrapper
    return await main
           ^^^^^^^^^^
  File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/vllm/entrypoints/openai/api_server.py", line 649, in run_server
    async with build_async_engine_client(args) as engine_client:
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/contextlib.py", line 210, in __aenter__
    return await anext(self.gen)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/vllm/entrypoints/openai/api_server.py", line 116, in build_async_engine_client
    async with build_async_engine_client_from_engine_args(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/contextlib.py", line 210, in __aenter__
    return await anext(self.gen)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/.virtualenvs/vllm312/lib/python3.12/site-packages/vllm/entrypoints/openai/api_server.py", line 213, in build_async_engine_client_from_engine_args
    raise RuntimeError(
RuntimeError: Engine process failed to start. See stack trace for the root cause.
```
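The failure is the standard Python duplicate-keyword-argument error: the preprocess step already puts `is_pixtral` into `hf_inputs`, and the call site then passes it again as an explicit keyword. A minimal reproduction of the pattern, with no vLLM involved (`make_kwargs` is a stand-in for `MultiModalKwargs`):

```python
def make_kwargs(**kwargs):
    return kwargs

hf_inputs = {"is_pixtral": True}  # already set by an earlier step

# Unpacking hf_inputs and also passing is_pixtral explicitly supplies the
# same keyword twice, which raises:
# TypeError: make_kwargs() got multiple values for keyword argument 'is_pixtral'
make_kwargs(**hf_inputs, is_pixtral=True)
```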

A quick search didn't turn up any open issues reporting this error yet.

This is a quick, minimal fix that ensures `is_pixtral` appears in `hf_inputs` only once, adding it only when the preprocess function hasn't already inserted it (as currently done via `hf_inputs["is_pixtral"] = torch.tensor(True)`).

This gets my model started again and seems to be working fine.
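For illustration, a minimal sketch of that guard (the standalone helper name and dict-in/dict-out shape are assumptions; in the PR the check lives in `vllm/model_executor/models/llava.py`):

```python
import torch

# Hypothetical helper sketching the fix: only add "is_pixtral" when the
# preprocess step has not already placed it in hf_inputs, so the key is
# never supplied to MultiModalKwargs twice.
def ensure_pixtral_flag(hf_inputs: dict) -> dict:
    if "is_pixtral" not in hf_inputs:
        hf_inputs["is_pixtral"] = torch.tensor(True)
    return hf_inputs
```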

github-actions bot commented on Dec 9, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

@DarkLight1337 (Member) commented

Sorry for breaking the model! I think a better solution would be to move this function into the LlavaProcessor and use the patched _get_hf_processor method to get the correct image preprocessing.
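A rough skeleton of that suggestion (the class body is hypothetical; only the `LlavaProcessor` and `_get_hf_processor` names come from the comment above, and vLLM's real processor base class and signatures are not reproduced here):

```python
# Hypothetical skeleton only, to illustrate the suggested direction.
class LlavaProcessor:
    def __init__(self, hf_processor):
        self._hf_processor = hf_processor

    def _get_hf_processor(self):
        # Hand back the patched HF processor so Pixtral images get the
        # correct preprocessing here, instead of each call site injecting
        # "is_pixtral" into the kwargs afterwards.
        return self._hf_processor
```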

@comaniac (Collaborator) commented

Got the same error and found this PR. We should get this in ASAP.

Commit: Remove the extra check ensuring `is_pixtral` is available; apparently reasonable confidence that it'll be there on all relevant codepaths 🤷🏻

Co-authored-by: Isotr0py <[email protected]>
@Isotr0py (Collaborator) left a comment

LGTM!

@comaniac added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Dec 12, 2024
@comaniac enabled auto-merge (squash) on Dec 12, 2024 at 16:42
@comaniac merged commit 5d71257 into vllm-project:main on Dec 12, 2024
65 checks passed
sleepwalker2017 pushed a commit to sleepwalker2017/vllm that referenced this pull request Dec 13, 2024