Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Support offloading KV cache to CPU #9682

Closed
wants to merge 17 commits into from

Conversation

KuntaiDu
Copy link
Collaborator

@KuntaiDu KuntaiDu commented Oct 25, 2024

A minmal implementation for CPU KV cache offloading (#7697)

Benchmarking results:

A long document QA workload (see google doc for more discriptions). GPU can cache 10 documents and CPU can cache 40 documents.
image
CPU offloading is better when GPU space is not enough to cache all documents but CPU can.

google doc link

Implementation

This PR has much less features compared to #8694, but it is really minimum and creates very little core change. So I guess we can use this PR to enable CPU KV cache offloading first, and then focus on disk.

The key idea of this implementation is to maintain those allocated blocks that didn't hit the cache, and constantly copy them into CPU after each scheduler step.

Here is the flow diagram
image

This idea is borrowed from ConServe (paper link: https://arxiv.org/abs/2410.01228), based on the assumption that the CPU-GPU bandwidth is much higher than GPU KV cache generation throughput. Thanks Yifan for this idea.

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Adding or changing kernels

Each custom kernel needs a schema and one or more implementations to be registered with PyTorch.

  • Make sure custom ops are registered following PyTorch guidelines: Custom C++ and CUDA Operators and The Custom Operators Manual
  • Custom operations that return Tensors require meta-functions. Meta-functions should be implemented and registered in python so that dynamic dims can be handled automatically. See above documents for a description of meta-functions.
  • Use torch.libary.opcheck() to test the function registration and meta-function for any registered ops. See tests/kernels for examples.
  • When changing the C++ signature of an existing op, the schema must be updated to reflect the changes.
  • If a new custom type is needed, see the following document: Custom Class Support in PT2.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@KuntaiDu KuntaiDu marked this pull request as draft October 25, 2024 05:22
@mergify mergify bot added the frontend label Oct 28, 2024
Copy link

mergify bot commented Oct 28, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. @KuntaiDu please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 28, 2024
@mergify mergify bot removed the needs-rebase label Oct 28, 2024
@KuntaiDu KuntaiDu marked this pull request as ready for review October 28, 2024 21:05
@ClarkChin08
Copy link

ClarkChin08 commented Oct 31, 2024

@KuntaiDu I tested this PR on A100 GPU and it will have listed issues:
[rank0]:[W1030 22:42:25.307407705 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4 (function operator())

After this warning, the process will lockup and machine should restart!
kernel:[67924.212374] watchdog: BUG: soft lockup - CPU#101 stuck for 22s! [python:187618]

@KuntaiDu
Copy link
Collaborator Author

KuntaiDu commented Nov 5, 2024

We also observe similar issue.

@KuntaiDu I tested this PR on A100 GPU and it will have listed issues: [rank0]:[W1030 22:42:25.307407705 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4 (function operator())

After this warning, the process will lockup and machine should restart! kernel:[67924.212374] watchdog: BUG: soft lockup - CPU#101 stuck for 22s! [python:187618]

Oh weird .... I didn't touch distributed initialization and destroying part. This should not be the case >.< Let me try to reproduce.

BTW, I am also working on a more performant cuda kernel for CPU-GPU memcpy, current memcpy kernel is ... really slow.

@ClarkChin08
Copy link

We also observe similar issue.

@KuntaiDu I tested this PR on A100 GPU and it will have listed issues: [rank0]:[W1030 22:42:25.307407705 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4 (function operator())
After this warning, the process will lockup and machine should restart! kernel:[67924.212374] watchdog: BUG: soft lockup - CPU#101 stuck for 22s! [python:187618]

Oh weird .... I didn't touch distributed initialization and destroying part. This should not be the case >.< Let me try to reproduce.

BTW, I am also working on a more performant cuda kernel for CPU-GPU memcpy, current memcpy kernel is ... really slow.

memcpy between CPU-GPU is quite important for the latency of data load/store, do you have more detailed information on this? How about the bandwith and the time consumed as the data increase?

@zachzzc
Copy link
Contributor

zachzzc commented Nov 7, 2024

Clean implementation! Just to verify my understanding, so some data will have copies in both GPU and CPU?
And do we still need to overlap the memory copies with the computation

@KuntaiDu
Copy link
Collaborator Author

KuntaiDu commented Nov 19, 2024

Clean implementation! Just to verify my understanding, so some data will have copies in both GPU and CPU? And do we still need to overlap the memory copies with the computation

  1. Yes, which is required to support preemption with recomputation (while still hitting KV cache in CPU)
  2. Yes (but not in this PR)

@yyccli
Copy link
Contributor

yyccli commented Nov 28, 2024

Hi
It seems that there is a bug when you try to extend the mapping list returned from the function get_and_reset_swaps to blocks_to_swap_out and blocks_to_swap_in lists. The index in blocks_to_swap_out/in lists should be the zero-offset block id on certain device instead of the absolute global block id, but index in the returned mapping list use global block id.

@ApostaC ApostaC requested a review from tlrmchlsmth as a code owner December 3, 2024 06:21
Copy link

mergify bot commented Dec 3, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @KuntaiDu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@KuntaiDu
Copy link
Collaborator Author

KuntaiDu commented Dec 3, 2024

Need to solve DCO issue in PR #10874 , so I close this PR.

@KuntaiDu KuntaiDu closed this Dec 3, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants