[V1] Optimize block table transfer from CPU to GPU #11401
Draft: WoosukKwon wants to merge 22 commits into main from v1-blocktable-opt (base: main).

Commits (22, all by WoosukKwon):
1aaced5  wip
8a4180c  yapf
03b1e6f  Minor
0a669ee  Minor
ee965c9  Use default
0420fb2  Merge branch 'main' into v1-blocktable-opt
3fdbd8e  comments
b938606  Merge branch 'main' into v1-blocktable-opt
ff5b103  Merge branch 'main' into v1-blocktable-opt
bef6816  Minor
5292219  Add test for uva
ca4f9e6  minor
27e8eb2  Add kernel test
34d6cc2  Merge branch 'main' into v1-blocktable-opt
6ba31aa  Minor
ebfbe12  ruff
a6e5d7b  Merge branch 'main' into v1-blocktable-opt
1260e43  Minor
ba64a02  Minor
1ca4298  Fix
f840b53  fix
7097f31  test
New file (@@ -0,0 +1,43 @@): a C++ helper that exposes a pinned CPU tensor as a CUDA tensor through UVA (Unified Virtual Addressing), avoiding an explicit host-to-device copy.
#include <torch/all.h>
#include <torch/cuda.h>

// This function assumes that `cpu_tensor` is a CPU tensor allocated with pinned
// memory, and that UVA (Unified Virtual Addressing) is enabled.
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) {
  TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU");
  TORCH_CHECK(cpu_tensor.is_contiguous(), "Input tensor must be contiguous");

  // Get raw host pointer from CPU tensor
  void* host_ptr = cpu_tensor.data_ptr();

  // Get a device pointer corresponding to the pinned host memory
  void* device_ptr = nullptr;
  cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
  TORCH_CHECK(err == cudaSuccess,
              "cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));

  // Construct a CUDA tensor from the device pointer.
  // We'll use the same sizes, strides, and dtype as the CPU tensor.
  auto sizes = cpu_tensor.sizes();
  auto strides = cpu_tensor.strides();
  auto options =
      cpu_tensor.options().device(torch::kCUDA);  // Change device to CUDA

  // from_blob signature: from_blob(void* data, IntArrayRef sizes, ..., Deleter,
  // const TensorOptions&). Provide a no-op deleter: the CPU tensor owns the
  // memory, so we don't free it here.
  auto deleter = [](void*) {
    // no-op, since the memory is owned by the original CPU tensor
  };

  torch::Tensor cuda_tensor =
      torch::from_blob(device_ptr, sizes, strides, deleter, options);

  TORCH_CHECK(cuda_tensor.device().is_cuda(),
              "Resulting tensor is not on CUDA device");
  TORCH_CHECK(cuda_tensor.sizes().equals(sizes), "Size mismatch");
  TORCH_CHECK(cuda_tensor.strides().equals(strides), "Stride mismatch");
  TORCH_CHECK(cuda_tensor.dtype() == cpu_tensor.dtype(), "Dtype mismatch");

  return cuda_tensor;
}
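The helper above assumes UVA is already enabled; the Python tests later in this diff guard on is_uva_available() from vllm.utils. As a rough sketch of how such a guard could look in C++ (an illustration built on the standard CUDA attribute query, not code from this PR, and the real helper may be implemented differently):

#include <cuda_runtime.h>

// Illustrative only: report whether the current device supports unified
// addressing, which get_cuda_view_from_cpu_tensor relies on.
bool uva_available_sketch() {
  int device = -1;
  if (cudaGetDevice(&device) != cudaSuccess) {
    return false;
  }
  int unified_addressing = 0;
  cudaError_t err = cudaDeviceGetAttribute(
      &unified_addressing, cudaDevAttrUnifiedAddressing, device);
  return err == cudaSuccess && unified_addressing == 1;
}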
New file (@@ -0,0 +1,75 @@): a CUDA kernel, copy_subranges, that copies only the changed subrange of each row from a source matrix to a target matrix; row i's changed range is given as a (start, length) pair in matrix_diff.
#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

namespace vllm {
__global__ void copy_subranges_kernel(const int* __restrict__ matrix_src,
                                      const int* __restrict__ matrix_diff,
                                      int* __restrict__ matrix_tgt,
                                      int64_t M) {
  int row_id = blockIdx.x;
  // Use a 64-bit offset so large matrices (N * M > INT_MAX) don't overflow.
  int64_t row_offset = row_id * M;

  int start = matrix_diff[row_id * 2];
  int length = matrix_diff[row_id * 2 + 1];
  int end = start + length;
  int thread_idx = threadIdx.x;
  for (int i = start + thread_idx; i < end; i += blockDim.x) {
    int64_t idx = row_offset + i;
    matrix_tgt[idx] = matrix_src[idx];
  }
}
}  // namespace vllm

void copy_subranges(torch::Tensor& matrix_src, torch::Tensor& matrix_diff,
                    torch::Tensor& matrix_tgt, int64_t n) {
  // NOTE(woosuk): Here, we skip most of the error checking to minimize the
  // CPU overheads. We assume that the caller will pass the correct inputs.

  // Check tensor properties
  // TORCH_CHECK(matrix_src.is_cuda(), "matrix_src must be a CUDA tensor");
  // TORCH_CHECK(matrix_diff.is_cuda(), "matrix_diff must be a CUDA tensor");
  // TORCH_CHECK(matrix_tgt.is_cuda(), "matrix_tgt must be a CUDA tensor");
  // TORCH_CHECK(matrix_src.is_contiguous(), "matrix_src must be contiguous");
  // TORCH_CHECK(matrix_diff.is_contiguous(), "matrix_diff must be contiguous");
  // TORCH_CHECK(matrix_tgt.is_contiguous(), "matrix_tgt must be contiguous");

  auto src_sizes = matrix_src.sizes();
  auto diff_sizes = matrix_diff.sizes();
  auto tgt_sizes = matrix_tgt.sizes();

  // TORCH_CHECK(src_sizes.size() == 2, "matrix_src must be 2D");
  // TORCH_CHECK(diff_sizes.size() == 2, "matrix_diff must be 2D");
  // TORCH_CHECK(tgt_sizes.size() == 2, "matrix_tgt must be 2D");

  int64_t N = src_sizes[0];
  int64_t M = src_sizes[1];

  // TORCH_CHECK(diff_sizes[0] == N, "matrix_diff first dim must match N");
  // TORCH_CHECK(diff_sizes[1] == 2, "matrix_diff second dim must be 2");
  // TORCH_CHECK(tgt_sizes[0] == N && tgt_sizes[1] == M,
  //             "matrix_tgt must have same shape as matrix_src");

  // TORCH_CHECK(n <= N, "n must be <= N");

  const int* d_matrix_src = matrix_src.data_ptr<int>();
  const int* d_matrix_diff = matrix_diff.data_ptr<int>();
  int* d_matrix_tgt = matrix_tgt.data_ptr<int>();

  // One thread block per row; use fewer threads per block as the grid grows.
  int blocks = n;
  int threads;
  if (blocks < 128) {
    threads = 1024;
  } else if (blocks < 256) {
    threads = 512;
  } else if (blocks < 512) {
    threads = 256;
  } else {
    threads = 128;
  }
  const at::cuda::OptionalCUDAGuard device_guard(device_of(matrix_tgt));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  vllm::copy_subranges_kernel<<<blocks, threads, 0, stream>>>(
      d_matrix_src, d_matrix_diff, d_matrix_tgt, M);
}

Reviewer comment on `int blocks = n;`: it seems this can easily oversubscribe GPU SMs.
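One way to address that concern, sketched purely for illustration (not part of this diff), is a grid-stride loop over rows, so the grid size can be capped near the SM count instead of growing with n:

// Illustrative variant: each block walks multiple rows, so the launch can
// clamp `blocks` (e.g. to a small multiple of the SM count).
__global__ void copy_subranges_kernel_strided(const int* __restrict__ src,
                                              const int* __restrict__ diff,
                                              int* __restrict__ tgt,
                                              int64_t n, int64_t M) {
  for (int64_t row = blockIdx.x; row < n; row += gridDim.x) {
    int start = diff[row * 2];
    int end = start + diff[row * 2 + 1];
    int64_t row_offset = row * M;
    for (int i = start + threadIdx.x; i < end; i += blockDim.x) {
      tgt[row_offset + i] = src[row_offset + i];
    }
  }
}

// Hypothetical launch: bound the grid by the device's SM count.
// int sm_count = 0;
// cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
// int blocks = static_cast<int>(std::min<int64_t>(n, 4 * sm_count));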
New test file (@@ -0,0 +1,47 @@): fills a random subrange of each source row, records the (start, length) diff per row, and checks that copy_subranges reproduces the source in the destination.
import random

import pytest
import torch

from vllm import _custom_ops as ops
from vllm.platforms import current_platform

SEEDS = [0]
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_copy_subranges(seed, device):
    torch.set_default_device(device)
    current_platform.seed_everything(seed)

    num_rows = 1024
    num_cols = 1024
    src_matrix = torch.zeros(num_rows,
                             num_cols,
                             device=device,
                             dtype=torch.int32)
    dst_matrix = torch.zeros(num_rows,
                             num_cols,
                             device=device,
                             dtype=torch.int32)
    diff_matrix = torch.zeros(num_rows, 2, device=device, dtype=torch.int32)

    for i in range(num_rows):
        start_idx = random.randint(0, num_cols - 1)
        end_idx = random.randint(start_idx, num_cols - 1)
        num_diffs = end_idx - start_idx

        src_matrix[i, start_idx:end_idx] = torch.randint(0,
                                                         100, (num_diffs, ),
                                                         device=device,
                                                         dtype=torch.int32)

        diff_matrix[i, 0] = start_idx
        diff_matrix[i, 1] = num_diffs

    ops.copy_subranges(src_matrix, diff_matrix, dst_matrix, num_rows)
    assert torch.allclose(src_matrix, dst_matrix, rtol=0, atol=0)
New test file (@@ -0,0 +1,60 @@): verifies that writes on either side of a UVA mapping are visible on the other side, in both the CPU-write/GPU-read and GPU-write/CPU-read directions.
import pytest
import torch

from vllm.utils import get_cuda_view_from_cpu_tensor, is_uva_available

CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


@pytest.mark.skipif(not is_uva_available(), reason="UVA is not available.")
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_cpu_write(device):
    torch.set_default_device(device)
    cpu_tensor = torch.zeros(10,
                             10,
                             device="cpu",
                             pin_memory=True,
                             dtype=torch.int32)
    cuda_view = get_cuda_view_from_cpu_tensor(cpu_tensor)
    assert cuda_view.device.type == "cuda"

    assert cuda_view[0, 0] == 0
    assert cuda_view[2, 3] == 0
    assert cuda_view[4, 5] == 0

    cpu_tensor[0, 0] = 1
    cpu_tensor[2, 3] = 2
    cpu_tensor[4, 5] = -1

    cuda_view.mul_(2)
    assert cuda_view[0, 0] == 2
    assert cuda_view[2, 3] == 4
    assert cuda_view[4, 5] == -2


@pytest.mark.skipif(not is_uva_available(), reason="UVA is not available.")
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_gpu_write(device):
    torch.set_default_device(device)
    cpu_tensor = torch.zeros(10,
                             10,
                             device="cpu",
                             pin_memory=True,
                             dtype=torch.int32)
    cuda_view = get_cuda_view_from_cpu_tensor(cpu_tensor)
    assert cuda_view.device.type == "cuda"

    assert cuda_view[0, 0] == 0
    assert cuda_view[2, 3] == 0
    assert cuda_view[4, 5] == 0

    cuda_view[0, 0] = 1
    cuda_view[2, 3] = 2
    cuda_view[4, 5] = -1
    cuda_view.mul_(2)

    assert cpu_tensor[0, 0] == 2
    assert cpu_tensor[2, 3] == 4
    assert cpu_tensor[4, 5] == -2
New test file (@@ -0,0 +1,52 @@): exercises the BlockTable add/append/move/commit path and checks the GPU block table against its pinned CPU mirror.
import random
import time

import pytest
import torch

from vllm.v1.worker.gpu_block_table import BlockTable

MAX_NUM_REQS = 1024
MAX_MODEL_LEN = 128 * 1024
BLOCK_SIZE = 16
MAX_NUM_BLOCKS_PER_REQ = MAX_MODEL_LEN // BLOCK_SIZE


@pytest.mark.parametrize("do_wait", [False, True])
def test_block_table(do_wait: bool):
    random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)

    block_table = BlockTable(
        max_num_reqs=MAX_NUM_REQS,
        max_model_len=MAX_MODEL_LEN,
        max_num_blocks_per_req=MAX_NUM_BLOCKS_PER_REQ,
        pin_memory=True,
        device=torch.device(0),
    )

    num_blocks = random.randint(1, MAX_NUM_BLOCKS_PER_REQ - 1)
    block_ids = torch.randint(0,
                              MAX_NUM_BLOCKS_PER_REQ, (num_blocks, ),
                              dtype=torch.int32,
                              device="cpu")
    block_table.add_row(0, block_ids)
    num_blocks = random.randint(1, MAX_NUM_BLOCKS_PER_REQ - 100)
    block_ids = torch.randint(0,
                              MAX_NUM_BLOCKS_PER_REQ, (num_blocks, ),
                              dtype=torch.int32,
                              device="cpu")
    block_table.add_row(1, block_ids)
    block_table.commit(2)

    torch.cuda.synchronize()
    if do_wait:
        time.sleep(1)

    block_ids = torch.randint(0,
                              MAX_NUM_BLOCKS_PER_REQ, (100, ),
                              dtype=torch.int32,
                              device="cpu")
    block_table.append_row(1, num_blocks, block_ids)
    block_table.move_row(1, 0)
    block_table.commit(2)

    torch.cuda.synchronize()
    if do_wait:
        time.sleep(1)

    torch.testing.assert_close(block_table.block_table[:1].cpu(),
                               block_table.block_table_cpu[:1])


if __name__ == "__main__":
    test_block_table(do_wait=False)
Reviewer comment on copy_subranges_kernel: most threads in the block would be idle; e.g., for decoding there is only one changed entry, or even none, per row of the block table.
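A sketch of one possible response, again illustrative rather than part of this PR: assign a single warp per row, so a tiny decode-time diff idles at most 32 lanes instead of a 128-1024 thread block.

// Illustrative variant: one 32-thread warp per row. For decode steps, where
// a row usually changes by at most one entry, far fewer lanes sit idle.
__global__ void copy_subranges_kernel_warp_per_row(
    const int* __restrict__ src, const int* __restrict__ diff,
    int* __restrict__ tgt, int64_t n, int64_t M) {
  constexpr int kWarpSize = 32;
  int warps_per_block = blockDim.x / kWarpSize;
  int64_t row =
      (int64_t)blockIdx.x * warps_per_block + threadIdx.x / kWarpSize;
  if (row >= n) return;
  int lane = threadIdx.x % kWarpSize;
  int start = diff[row * 2];
  int end = start + diff[row * 2 + 1];
  int64_t row_offset = row * M;
  for (int i = start + lane; i < end; i += kWarpSize) {
    tgt[row_offset + i] = src[row_offset + i];
  }
}

// Hypothetical launch with 128-thread blocks (4 warps per block):
// int64_t blocks = (n + 3) / 4;
// copy_subranges_kernel_warp_per_row<<<blocks, 128, 0, stream>>>(
//     d_src, d_diff, d_tgt, n, M);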