Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Support offloading KV cache to CPU #10874

Open
wants to merge 13 commits into
base: main
Choose a base branch
from

Conversation

ApostaC
Copy link

@ApostaC ApostaC commented Dec 3, 2024

An implementation for CPU KV cache offloading (#7697)

TL; DR: CPU offloading is better than prefix caching in our benchmark, we also found that the evictor can be optimized to save 10-30% of the runtime.

This PR is for fixing the DCO issue for the Kuntai's original CPU offloading PR #9682 . It also contains new CUDA kernels to improve the KV cache offloading performance.

End-to-end benchmarking results:

A long document QA workload (see benchmarks/benchmark_long_document_qa.py) running on A100-40G-SXM GPU. The GPU can cache 8 documents and the CPU can cache 30 documents.

image

(Following are the original data for the above figure)

Num documents vLLM vLLM w/ prefix caching vLLM w/ prefix caching + CPU offloading
8 13.66 0.49 0.5
16 27.28 7.22 2.3
32 54.54 49.96 17.26
64 109.27 126.08 110.96

New kernel implementation microbenchmark

The numbers are collected on A100-40GB-SXM GPUs

# of pages New implementation Old implementation
Swap in 1 page 1.21 GB / second 1.21 GB / second
Swap in 1250 pages 12.6 GB / second 3.11 GB / second
Swap out 1 page 1.22 GB / second 1.21 GB / second
Swap out 1250 pages 12.5 GB / second 3.12 GB / seocond

The new kernel can achieve 4x better throughput than the old swap_block implementation.
Also, it won't decrease the performance when the number of pages are small.

Potential improvement:

Currently, the swap_block is invoked once per layer. If we can aggregate the copy of all the layers into one kernel, the throughput of copying 1 page will also achieve >10GB/s.

Implementation

This PR has much less features compared to #8694, but it is really minimum and creates very little core change. So I guess we can use this PR to enable CPU KV cache offloading first, and then focus on disk.

The key idea of this implementation is to maintain those allocated blocks that didn't hit the cache, and constantly copy them into CPU after each scheduler step.

Here is the flow diagram
image

This idea is borrowed from ConServe (paper link: https://arxiv.org/abs/2410.01228), based on the assumption that the CPU-GPU bandwidth is much higher than GPU KV cache generation throughput. Thanks Yifan for this idea.

Copy link

github-actions bot commented Dec 3, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

ApostaC and others added 2 commits December 3, 2024 20:59
@ApostaC ApostaC force-pushed the yihua-cpu-offloading2 branch from da35ed9 to e6654f2 Compare December 3, 2024 21:00
@@ -362,7 +362,7 @@ def test_swap_blocks(
block_mapping = list(zip(src_blocks, dst_blocks))
block_mapping_tensor = torch.tensor(block_mapping,
dtype=torch.int64,
device="cpu").view(-1, 2)
device=device).view(-1, 2)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this because that this tensor need to be accessed by the new CUDA memcpy kernel?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. The new paged_copy kernel need to access the block mapping from GPU.

@KuntaiDu KuntaiDu added ready ONLY add when PR is ready to merge/full CI is needed and removed ready ONLY add when PR is ready to merge/full CI is needed labels Dec 5, 2024
@@ -508,3 +523,19 @@ def get_num_cached_tokens(self, seq: Sequence) -> int:
cached in the block manager for the sequence.
"""
return self._computed_blocks_tracker.get_num_cached_tokens(seq)

def get_and_reset_swaps(self,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function seems like not get the real physical block ID from get_physical_block_id? Especially for CPU PrefixCachingBlockAllocator, whose start ID is not zero.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly, this function should not return the physical block IDs because the get_physical_block_id will be called in block_manager.swap_in()/block_manager.swap_out() later.

Call chain is: scheduler._swap_in() --> block_manager.swap_in() --> block_allocator.get_physical_block_id() (similar for swapping out).

(Let me know if my understanding is incorrect and I will fix it asap, thanks!)


# NOTE(Kuntai): extend the swapping list for CPU offloading
new_swap_out, new_swap_in = \
self.block_manager.get_and_reset_swaps(time.time())

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However, the get_and_reset_swaps is called directly here without get_physical_block_id. I think these block IDs are sent to the cache engine directly later, so not the real physical block IDs.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it! I double-checked the logic and you are right. Just pushed another commit to fix the issue and update the docstring. Thanks for the catch!

Copy link

mergify bot commented Dec 9, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ApostaC.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 9, 2024
@mergify mergify bot removed the needs-rebase label Dec 9, 2024
@KuntaiDu KuntaiDu added the ready ONLY add when PR is ready to merge/full CI is needed label Dec 11, 2024
uncached: allocated blocks that didn't hit any cache
cached: allocated blocks that are cached, either in GPU or in CPU
free: the blocks are not allocated by block allocator
This implementation aims to transform uncacherd blocks to cached blocks
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uncacherd blocks -> uncached blocks?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the catch! Fixed!

@ApostaC ApostaC force-pushed the yihua-cpu-offloading2 branch 2 times, most recently from 583984f to 789b00e Compare December 13, 2024 16:41
@ApostaC ApostaC force-pushed the yihua-cpu-offloading2 branch from 9fcf23e to 2648fa5 Compare December 13, 2024 17:12
Comment on lines +417 to +427
parser.add_argument(
'--block-allocator',
type=str,
default='CpuGpuBlockAllocator',
choices=['CpuGpuBlockAllocator', 'CpuOffloadingBlockAllocator'],
help='The block allocator that vLLM uses. Currently'
' can be CpuGpuBlockAllocator (the default) and '
'CpuOffloadingBlockAllocator (experimental) that '
'supports offloading the KV cache to CPU . '
'When using CpuOffloadingBlockAllocator, the '
'preemption mode must be recompute.')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we avoid to expose block allocator? Instead we should provide something like --kv-cache-offloading.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After a second thought, for the sake of modularity and extensibility, I think it's fine to expose a block allocator argument, but we should keep the default name concise, and also keep the possibility to accept third-party allocators.

I recommend adding --allocator, which can take several default values as above, and also a class qualname like mod.cls .

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does this interface work with v1?

@@ -322,10 +322,10 @@ def prepare_worker_input(
# `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
# they contain parameters to launch cudamemcpyasync.
blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
device="cpu",
device="cuda",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't be compatible with other devices I think?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As suggested by Kaichao, we will use a static pinned CPU memory buffer to host this array.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we separate the benchmark scripts to another PR to reduce the size of this one?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah sure!

Copy link

mergify bot commented Dec 17, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ApostaC.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 17, 2024
return static_cast<T*>(tensor.data_ptr());
} else if (device.is_cpu() && tensor.is_pinned()) {
T* ptr;
cudaHostGetDevicePointer((void**)&ptr, static_cast<T*>(tensor.data_ptr()),
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Save this pointer at the creation of the CPU memory allocation -- making it CUDA graph compatible.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potentially in a new PR

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
frontend needs-rebase ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants