-
-
Notifications
You must be signed in to change notification settings - Fork 5.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Core] Support offloading KV cache to CPU #10874
Open
ApostaC
wants to merge
13
commits into
vllm-project:main
Choose a base branch
from
KuntaiDu:yihua-cpu-offloading2
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 4 commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
f60a8fa
Move to a new branch to fix the DCO issues.
ApostaC e6654f2
[Fix] the failed unit tests
ApostaC ba6c9e3
[Fix] CPU offloading not working bug and [fix] unit test and format i…
ApostaC 1c94985
[fix] broken tests for cpu offloading allocator
ApostaC daab0d6
[Fix] add the call to get_physical_block_ids
ApostaC 919e5e3
[Add] faster unsafe implementation for get_physical_block_id
ApostaC 52185bf
Merge branch 'main' into yihua-cpu-offloading2
ApostaC 505e60c
Updating the benchmark script with correct usage instructions
ApostaC a517a29
make yapf happy
ApostaC 789b00e
fix format checker issues
ApostaC a0b5061
Merge remote-tracking branch 'upstream/main' into yihua-cpu-offloading2
ApostaC 2648fa5
Fix the compatibility witht the latest main
ApostaC 278e166
Fix the typo uncacherd -> uncached
ApostaC File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
""" | ||
Benchmark the efficiency of prefix caching. | ||
|
||
This script allows you to benchmark the performance of | ||
a model with and without prefix caching using either fixed prompts | ||
or prompts sampled from the ShareGPT dataset. | ||
|
||
Fixed example usage: | ||
python benchmark_prefix_caching.py \ | ||
--model meta-llama/Llama-2-7b-chat-hf \ | ||
--enable-prefix-caching \ | ||
--num-prompts 1 \ | ||
--repeat-count 100 | ||
|
||
ShareGPT example usage: | ||
# This command samples 20 prompts with input lengths | ||
# between 128 and 256 tokens from the ShareGPT dataset, | ||
# then replicates each prompt 5 times. | ||
python benchmark_prefix_caching.py \ | ||
--model meta-llama/Llama-2-7b-chat-hf \ | ||
--dataset-path /path/to/ShareGPT_V3_unfiltered_cleaned_split.json \ | ||
--enable-prefix-caching \ | ||
--num-prompts 20 \ | ||
--repeat-count 5 \ | ||
--input-length-range 128:256 | ||
""" | ||
|
||
import random | ||
import time | ||
|
||
from vllm import LLM, SamplingParams | ||
from vllm.utils import FlexibleArgumentParser | ||
|
||
|
||
def test_long_document_qa(llm=None, sampling_params=None, prompts=None): | ||
|
||
start_time = time.time() | ||
llm.generate(prompts, sampling_params=sampling_params) | ||
end_time = time.time() | ||
print(f"cost time {end_time - start_time}") | ||
|
||
|
||
def repeat_prompts(prompts, repeat_count): | ||
repeated_prompts = prompts * repeat_count | ||
random.shuffle(repeated_prompts) | ||
return repeated_prompts | ||
|
||
|
||
def main(args): | ||
|
||
random.seed(args.seed) | ||
|
||
# append the document id at the beginning to avoid any of the document | ||
# being the prefix of other documents | ||
prompts = [ | ||
str(i) + ' '.join(['hi'] * args.document_length) | ||
for i in range(args.num_documents) | ||
] | ||
|
||
preemption_mode = "" | ||
if args.block_allocator == "CpuOffloadingBlockAllocator": | ||
preemption_mode = "recompute" | ||
else: | ||
preemption_mode = "swap" | ||
|
||
llm = LLM(model=args.model, | ||
tokenizer_mode='auto', | ||
trust_remote_code=True, | ||
enforce_eager=True, | ||
tensor_parallel_size=args.tensor_parallel_size, | ||
enable_prefix_caching=args.enable_prefix_caching, | ||
block_allocator=args.block_allocator, | ||
preemption_mode=preemption_mode, | ||
swap_space=args.cpu_memory_gb, | ||
gpu_memory_utilization=args.gpu_memory_utilization, | ||
max_model_len=30000) | ||
|
||
sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) | ||
|
||
prompts = repeat_prompts(prompts, args.repeat_count) | ||
|
||
print("------warm up------") | ||
test_long_document_qa( | ||
llm=llm, | ||
prompts=prompts, | ||
sampling_params=sampling_params, | ||
) | ||
|
||
print("------start generating------") | ||
test_long_document_qa( | ||
llm=llm, | ||
prompts=prompts, | ||
sampling_params=sampling_params, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = FlexibleArgumentParser( | ||
description= | ||
'Benchmark the performance with or without automatic prefix caching.') | ||
parser.add_argument( | ||
'--model', | ||
type=str, | ||
# this test aims to test long document QA capability, | ||
# so we use llama 3.1 8B as it can process long context | ||
default='meta-llama/Llama-3.1-8B') | ||
parser.add_argument("--dataset-path", | ||
type=str, | ||
default=None, | ||
help="Path to the dataset.") | ||
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) | ||
parser.add_argument('--output-len', type=int, default=10) | ||
parser.add_argument('--enable-prefix-caching', | ||
action='store_true', | ||
help='enable prefix caching') | ||
parser.add_argument('--repeat-count', | ||
type=int, | ||
default=2, | ||
help='Number of times to repeat each prompt') | ||
parser.add_argument( | ||
'--document-length', | ||
type=int, | ||
# Roughly the number of tokens for a system paper, | ||
# excluding images | ||
default=20010, | ||
help='Range of input lengths for sampling prompts,' | ||
'specified as "min:max" (e.g., "128:256").') | ||
parser.add_argument('--num-documents', | ||
type=int, | ||
default=8, | ||
help='Range of input lengths for sampling prompts,' | ||
'specified as "min:max" (e.g., "128:256").') | ||
parser.add_argument("--seed", | ||
type=int, | ||
default=0, | ||
help='Random seed for reproducibility') | ||
parser.add_argument('--gpu-memory-utilization', | ||
type=float, | ||
default=0.5, | ||
help='GPU memory utilization for vLLM. Should be a ' | ||
'float point number ranging from 0 to 1. For this ' | ||
'test please use a small value so that the GPU ' | ||
'cannot hold all KV caches of all documents, ' | ||
'and the effect of CPU offloading can be tested.') | ||
parser.add_argument( | ||
'--cpu-memory-gb', | ||
type=float, | ||
default=1, | ||
help="The amount of CPU memory (GB) that is used by vLLM. Not very " | ||
"useful for CpuGpuBlockAllocator, but useful for " | ||
"CpuOffloadingBlockAllocator to have more CPU KV cache space") | ||
parser.add_argument( | ||
'--block-allocator', | ||
type=str, | ||
default='CpuGpuBlockAllocator', | ||
choices=['CpuGpuBlockAllocator', 'CpuOffloadingBlockAllocator'], | ||
help='The block allocator that vLLM uses. Currently' | ||
' can be CpuGpuBlockAllocator (the default) and ' | ||
'CpuOffloadingBlockAllocator (experimental) that ' | ||
'supports offloading the KV cache to CPU . ' | ||
'When using CpuOffloadingBlockAllocator, the ' | ||
'preemption mode must be recompute.') | ||
args = parser.parse_args() | ||
main(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -244,4 +244,4 @@ def main(args): | |
|
||
parser = EngineArgs.add_cli_args(parser) | ||
args = parser.parse_args() | ||
main(args) | ||
main(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ | |
#include "quantization/fp8/nvidia/quant_utils.cuh" | ||
#endif | ||
|
||
#include <cstdio> | ||
#include <algorithm> | ||
#include <cassert> | ||
#include <map> | ||
|
@@ -21,8 +22,64 @@ | |
typedef __hip_bfloat16 __nv_bfloat16; | ||
#endif | ||
|
||
void swap_blocks(torch::Tensor& src, torch::Tensor& dst, | ||
const torch::Tensor& block_mapping) { | ||
namespace vllm { | ||
|
||
template <typename T, typename ACC_T> | ||
__global__ void paged_copy(T* __restrict__ dst, const T* __restrict__ src, | ||
ACC_T src_to_dst, const int num_pages, | ||
const int num_elements_per_page) { | ||
const int64_t srcPageIdx = src_to_dst[blockIdx.x][0]; | ||
const int64_t dstPageIdx = src_to_dst[blockIdx.x][1]; | ||
|
||
const int64_t srcPageOffset = srcPageIdx * num_elements_per_page; | ||
const int64_t dstPageOffset = dstPageIdx * num_elements_per_page; | ||
|
||
for (int i = threadIdx.x; i < num_elements_per_page; i += blockDim.x) { | ||
dst[dstPageOffset + i] = src[srcPageOffset + i]; | ||
} | ||
} | ||
|
||
} // namespace vllm | ||
|
||
template <int DTYPE_LEN, typename DTYPE> | ||
void launch_swap_block_kernel(DTYPE* dst, const DTYPE* src, | ||
const torch::Tensor& block_mapping, | ||
const int num_blocks, | ||
const int block_size_in_bytes) { | ||
c10::cuda::CUDAGuard device_guard(block_mapping.device()); | ||
auto block_mapping_accessor = | ||
block_mapping.packed_accessor32<int64_t, 2, torch::RestrictPtrTraits>(); | ||
|
||
int num_threads = 1024; | ||
int grid_size = num_blocks; | ||
|
||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||
vllm::paged_copy<<<grid_size, num_threads, 0, stream>>>( | ||
dst, src, block_mapping_accessor, num_blocks, | ||
block_size_in_bytes / DTYPE_LEN); | ||
} | ||
|
||
template <typename T> | ||
T* get_kernel_ptr(torch::Tensor& tensor) { | ||
// Get the kernel-accessible pointer of the given type T | ||
// Returns NULL if the tensor is on CPU and non-pinned | ||
torch::Device device = tensor.device(); | ||
if (device.is_cuda()) { | ||
return static_cast<T*>(tensor.data_ptr()); | ||
} else if (device.is_cpu() && tensor.is_pinned()) { | ||
T* ptr; | ||
cudaHostGetDevicePointer((void**)&ptr, static_cast<T*>(tensor.data_ptr()), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Save this pointer at the creation of the CPU memory allocation -- making it CUDA graph compatible. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Potentially in a new PR |
||
0); | ||
return ptr; | ||
} else if (device.is_cpu()) { | ||
return NULL; | ||
} else { | ||
TORCH_CHECK(false, "Invalid device"); | ||
} | ||
} | ||
|
||
void swap_blocks_slow(torch::Tensor& src, torch::Tensor& dst, | ||
const torch::Tensor& block_mapping) { | ||
torch::Device src_device = src.device(); | ||
torch::Device dst_device = dst.device(); | ||
cudaMemcpyKind memcpy_type; | ||
|
@@ -62,6 +119,41 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst, | |
} | ||
} | ||
|
||
void swap_blocks(torch::Tensor& src, torch::Tensor& dst, | ||
const torch::Tensor& block_mapping) { | ||
int64_t* src_ptr = get_kernel_ptr<int64_t>(src); | ||
int64_t* dst_ptr = get_kernel_ptr<int64_t>(dst); | ||
if (src_ptr == NULL || dst_ptr == NULL) { | ||
// fall back to the slow implementation | ||
swap_blocks_slow(src, dst, block_mapping.cpu()); | ||
} else { | ||
// Check the device | ||
torch::Device src_device = src.device(); | ||
torch::Device dst_device = dst.device(); | ||
torch::Device block_mapping_device = block_mapping.device(); | ||
TORCH_CHECK(block_mapping_device.is_cuda(), "block_mapping must be on GPU"); | ||
if (src_device.is_cuda() && dst_device.is_cuda()) { | ||
TORCH_CHECK(src_device.index() == dst_device.index(), | ||
"src and dst must be on the same GPU"); | ||
} | ||
if (src_device.is_cuda()) { | ||
TORCH_CHECK(src_device.index() == block_mapping_device.index(), | ||
"src and block_mapping must be on the same GPU"); | ||
} | ||
if (dst_device.is_cuda()) { | ||
TORCH_CHECK(dst_device.index() == block_mapping_device.index(), | ||
"src and block_mapping must be on the same GPU"); | ||
} | ||
|
||
const int64_t num_blocks = block_mapping.size(0); | ||
const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); | ||
|
||
launch_swap_block_kernel<8, int64_t>(dst_ptr, (const int64_t*)src_ptr, | ||
block_mapping, num_blocks, | ||
block_size_in_bytes); | ||
} | ||
} | ||
|
||
namespace vllm { | ||
|
||
// Grid: (num_layers, num_pairs) | ||
|
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we separate the benchmark scripts to another PR to reduce the size of this one?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah sure!