Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Misc] Upgrade to pytorch 2.5 #9588

Merged
merged 10 commits into from
Oct 27, 2024
Merged

[Misc] Upgrade to pytorch 2.5 #9588

merged 10 commits into from
Oct 27, 2024

Conversation

bnellnm
Copy link
Contributor

@bnellnm bnellnm commented Oct 22, 2024

Upgrade to pytorch 2.5

Requires changes to flash attn: vllm-project/flash-attention#23


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Adding or changing kernels

Each custom kernel needs a schema and one or more implementations to be registered with PyTorch.

  • Make sure custom ops are registered following PyTorch guidelines: Custom C++ and CUDA Operators and The Custom Operators Manual
  • Custom operations that return Tensors require meta-functions. Meta-functions should be implemented and registered in python so that dynamic dims can be handled automatically. See above documents for a description of meta-functions.
  • Use torch.libary.opcheck() to test the function registration and meta-function for any registered ops. See tests/kernels for examples.
  • When changing the C++ signature of an existing op, the schema must be updated to reflect the changes.
  • If a new custom type is needed, see the following document: Custom Class Support in PT2.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@bnellnm bnellnm changed the title Upgrade to pytorch 2.5 [Misc] Upgrade to pytorch 2.5 Oct 22, 2024
@bnellnm
Copy link
Contributor Author

bnellnm commented Oct 22, 2024

/ready

@@ -4,7 +4,7 @@
# Dependencies for NVIDIA GPUs
ray >= 2.9
nvidia-ml-py # for pynvml package
torch == 2.4.0
torch == 2.5.0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only concern here is now torch==2.5.0 uses the 12.4 cuda bindings by default. We might want to update the installation docs (including on the readme) to alert users that they may want to pass --extra-index-url https://download.pytorch.org/whl/cu121 during installation depending on the machine they are using

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@robertgshaw2-neuralmagic Does this mean that there would need to be multiple vllm packages (one for 12.1 and one for 12.4)? Or should I try to install pytorch 2.5 built with 12.1 (if such a thing exists)?

@youkaichao
Copy link
Member

@bnellnm it loos like there are some cmake errors:

[2024-10-22T15:09:27Z] #25 15.41 CMake Error: The following variables are used in this project, but they are set to NOTFOUND.
--
  | [2024-10-22T15:09:27Z] #25 15.41 Please set them or make sure they are set and tested correctly in the CMake files:
  | [2024-10-22T15:09:27Z] #25 15.41 CUDA_CUDA_LIBRARY (ADVANCED)
  | [2024-10-22T15:09:27Z] #25 15.41     linked by target "_moe_C" in directory /workspace
  | [2024-10-22T15:09:27Z] #25 15.41     linked by target "_C" in directory /workspace
  | [2024-10-22T15:09:27Z] #25 15.41     linked by target "vllm_flash_attn_c" in directory /workspace/.deps/vllm-flash-attn-src
  | [2024-10-22T15:09:27Z] #25 15.41

the cuda version should not matter that much. I think our current pipeline should still work even if pytorch itself is built against cuda 12.4 .

@tlrmchlsmth
Copy link
Collaborator

@bnellnm it loos like there are some cmake errors:

[2024-10-22T15:09:27Z] #25 15.41 CMake Error: The following variables are used in this project, but they are set to NOTFOUND.
--
  | [2024-10-22T15:09:27Z] #25 15.41 Please set them or make sure they are set and tested correctly in the CMake files:
  | [2024-10-22T15:09:27Z] #25 15.41 CUDA_CUDA_LIBRARY (ADVANCED)
  | [2024-10-22T15:09:27Z] #25 15.41     linked by target "_moe_C" in directory /workspace
  | [2024-10-22T15:09:27Z] #25 15.41     linked by target "_C" in directory /workspace
  | [2024-10-22T15:09:27Z] #25 15.41     linked by target "vllm_flash_attn_c" in directory /workspace/.deps/vllm-flash-attn-src
  | [2024-10-22T15:09:27Z] #25 15.41

the cuda version should not matter that much. I think our current pipeline should still work even if pytorch itself is built against cuda 12.4 .

Looks related to #8609

endif()
target_link_libraries(${GPU_MOD_NAME} PRIVATE ${CUDA_CUDA_LIB}
${CUDA_LIBRARIES})
target_link_libraries(${GPU_MOD_NAME} PRIVATE CUDA::cudart CUDA::cuda_driver)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

pyproject.toml Outdated
@@ -6,7 +6,7 @@ requires = [
"packaging",
"setuptools>=61",
"setuptools-scm>=8.0",
"torch == 2.4.0",
"torch == 2.5.0 --extra-index-url https://download.pytorch.org/whl/cu121",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given @youkaichao's comment:

the cuda version should not matter that much. I think our current pipeline should still work even if pytorch itself is built against cuda 12.4 .

We should consider ditching the --extra-index-url. Perhaps this should be configurable, but one thing to note is that 2:4 sparse fp8 will require the Pytorch version that's built with 12.4

Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PR looks good but we should quickly come to a consensus on what to do with the CUDA version that pytorch is built against

@robertgshaw2-neuralmagic
Copy link
Collaborator

robertgshaw2-neuralmagic commented Oct 25, 2024

The PR looks good but we should quickly come to a consensus on what to do with the CUDA version that pytorch is built against

Im fine to remove the --extra-index-url if we decide to make 12.4 the default wheel for vllm. But we should still ship 12.1 and 11.8 wheels IMO

CMakeLists.txt Show resolved Hide resolved
@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 25, 2024
@DarkLight1337
Copy link
Member

Please merge from main to fix the CI failures for multi-modal models.

Copy link
Member

@youkaichao youkaichao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

excited to see it happen!

Signed-off-by: youkaichao <[email protected]>
@youkaichao
Copy link
Member

some errors are real:

cuDNN Frontend error: [cudnn_frontend] Error: No execution plans support the graph.

huggingface/diffusers#9704

fixing by 0068133

Signed-off-by: youkaichao <[email protected]>
@youkaichao youkaichao requested a review from ywang96 as a code owner October 27, 2024 07:36
@youkaichao
Copy link
Member

pytorch 2.5 changes the output slightly:

[2024-10-27T04:21:27Z] hf: ‘vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.\n\n- Advantages of using the LLM:\n - High-fidelity and accurate predictions: The LLM can generate high-quality and context’
[2024-10-27T04:21:27Z] vllm: ‘vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.\n\n- Advantages of using the LLM:\n - High-efficiency and low-latency inference: The LLM provides fast and efficient’

The output is still sensible.

Therefore I changed it to logprobs check instead.

For future reference, we can also change to logprobs check if exact comparison is not feasible while it is not our fault (due to pytorch or huggingface numerical change).

@youkaichao youkaichao enabled auto-merge (squash) October 27, 2024 08:23
@youkaichao youkaichao merged commit 3cb07a3 into vllm-project:main Oct 27, 2024
80 checks passed
@youkaichao youkaichao deleted the pt25 branch October 27, 2024 15:19
cooleel pushed a commit to cooleel/vllm that referenced this pull request Oct 28, 2024
Signed-off-by: Bill Nell <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Signed-off-by: Shanshan Wang <[email protected]>
cooleel pushed a commit to cooleel/vllm that referenced this pull request Oct 28, 2024
Signed-off-by: Bill Nell <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Signed-off-by: Shanshan Wang <[email protected]>
FerdinandZhong pushed a commit to FerdinandZhong/vllm that referenced this pull request Oct 29, 2024
Signed-off-by: Bill Nell <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Signed-off-by: qishuai <[email protected]>
@fzyzcjy
Copy link
Contributor

fzyzcjy commented Oct 29, 2024

Hi, looking forward to this new support, I wonder when will it be released? Thanks!

mzusman added a commit to mzusman/vllm that referenced this pull request Oct 29, 2024
@youkaichao
Copy link
Member

Hi, looking forward to this new support, I wonder when will it be released? Thanks!

you can install the wheel from main branch to use it directly. see https://docs.vllm.ai/en/latest/getting_started/installation.html#install-the-latest-code

@fzyzcjy
Copy link
Contributor

fzyzcjy commented Oct 29, 2024

@youkaichao thanks! but is it less stable than a stable version?

rasmith pushed a commit to rasmith/vllm that referenced this pull request Oct 30, 2024
Signed-off-by: Bill Nell <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Signed-off-by: Randall Smith <[email protected]>
NickLucche pushed a commit to NickLucche/vllm that referenced this pull request Oct 31, 2024
Signed-off-by: Bill Nell <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Signed-off-by: NickLucche <[email protected]>
NickLucche pushed a commit to NickLucche/vllm that referenced this pull request Oct 31, 2024
Signed-off-by: Bill Nell <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Signed-off-by: NickLucche <[email protected]>
mzusman added a commit to mzusman/vllm that referenced this pull request Nov 3, 2024
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Nov 4, 2024
Signed-off-by: Bill Nell <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: youkaichao <[email protected]>
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Nov 4, 2024
Signed-off-by: Bill Nell <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Signed-off-by: Linkun Chen <[email protected]>
siddvenk pushed a commit to siddvenk/vllm that referenced this pull request Nov 5, 2024
Signed-off-by: Bill Nell <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: youkaichao <[email protected]>
sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
Signed-off-by: Bill Nell <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Signed-off-by: Sumit Dubey <[email protected]>
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
Signed-off-by: Bill Nell <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: youkaichao <[email protected]>
mfournioux pushed a commit to mfournioux/vllm that referenced this pull request Nov 20, 2024
Signed-off-by: Bill Nell <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Signed-off-by: Maxime Fournioux <[email protected]>
tlrmchlsmth pushed a commit to neuralmagic/vllm that referenced this pull request Nov 23, 2024
Signed-off-by: Bill Nell <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
sleepwalker2017 pushed a commit to sleepwalker2017/vllm that referenced this pull request Dec 13, 2024
Signed-off-by: Bill Nell <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants