[V1] Optimize block table transfer from CPU to GPU #11401

Draft: wants to merge 22 commits into main (changes from all commits).
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -193,13 +193,15 @@ set(VLLM_EXT_SRC
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/layernorm_quant_kernels.cu"
"csrc/cuda_view.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/prepare_inputs/copy_subranges.cu"
"csrc/torch_bindings.cpp")

if(VLLM_GPU_LANG STREQUAL "CUDA")
8 changes: 8 additions & 0 deletions csrc/cuda_compat.h
@@ -47,3 +47,11 @@
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif

// #ifndef USE_ROCM
// #define VLLM_cudaHostGetDevicePointer(device_ptr, host_ptr, flags) \
// cudaHostGetDevicePointer(device_ptr, host_ptr, flags)
// #else
// #define VLLM_cudaHostGetDevicePointer(device_ptr, host_ptr, flags) \
// hipHostGetDevicePointer(device_ptr, host_ptr, flags)
// #endif
43 changes: 43 additions & 0 deletions csrc/cuda_view.cu
@@ -0,0 +1,43 @@
#include <torch/all.h>
#include <torch/cuda.h>

// This function assumes that `cpu_tensor` is a CPU tensor allocated with pinned
// memory, and that UVA (Unified Virtual Addressing) is enabled.
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) {
TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU");
TORCH_CHECK(cpu_tensor.is_contiguous(), "Input tensor must be contiguous");

// Get raw host pointer from CPU tensor
void* host_ptr = cpu_tensor.data_ptr();

// Get a device pointer corresponding to the pinned host memory
void* device_ptr = nullptr;
cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
TORCH_CHECK(err == cudaSuccess,
"cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));

// Construct a CUDA tensor from the device pointer.
// We'll use the same sizes, strides, and dtype as the CPU tensor.
auto sizes = cpu_tensor.sizes();
auto strides = cpu_tensor.strides();
auto options =
cpu_tensor.options().device(torch::kCUDA); // Change device to CUDA

// from_blob signature: from_blob(void *data, IntArrayRef sizes, ..., Deleter,
// const TensorOptions &) Provide a no-op deleter. The CPU tensor holds the
// memory, so we don't free it here.
auto deleter = [](void*) {
// no-op, since the memory is owned by the original CPU tensor
};

torch::Tensor cuda_tensor =
torch::from_blob(device_ptr, sizes, strides, deleter, options);

TORCH_CHECK(cuda_tensor.device().is_cuda(),
"Resulting tensor is not on CUDA device");
TORCH_CHECK(cuda_tensor.sizes().equals(sizes), "Size mismatch");
TORCH_CHECK(cuda_tensor.strides().equals(strides), "Stride mismatch");
TORCH_CHECK(cuda_tensor.dtype() == cpu_tensor.dtype(), "Dtype mismatch");

return cuda_tensor;
}
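
The comment at the top of csrc/cuda_view.cu assumes a pinned CPU tensor and UVA. As context, below is a minimal standalone CUDA sketch, not part of this PR, of the mechanism the helper relies on: a pinned, mapped host allocation can be addressed from the GPU through the pointer returned by cudaHostGetDevicePointer, so kernel writes land directly in host memory (with UVA enabled, plain pinned allocations such as torch tensors created with pin_memory=True are device-mappable the same way). The file name and kernel are illustrative.

```cuda
// uva_demo.cu -- illustrative only; build with: nvcc uva_demo.cu -o uva_demo
#include <cstdio>
#include <cuda_runtime.h>

__global__ void increment(int* data, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] += 1;  // writes go straight to the pinned host buffer
}

int main() {
  const int n = 16;
  int* host_ptr = nullptr;
  // Pinned + mapped host allocation. With UVA, the runtime exposes a device
  // mapping for pinned host memory, which is what the PR's helper exploits.
  cudaHostAlloc(reinterpret_cast<void**>(&host_ptr), n * sizeof(int),
                cudaHostAllocMapped);
  for (int i = 0; i < n; ++i) host_ptr[i] = i;

  // Device-side alias of the same physical memory, as in
  // get_cuda_view_from_cpu_tensor.
  int* device_ptr = nullptr;
  cudaHostGetDevicePointer(reinterpret_cast<void**>(&device_ptr), host_ptr, 0);

  increment<<<1, n>>>(device_ptr, n);
  cudaDeviceSynchronize();  // make GPU writes visible before the CPU reads

  printf("host_ptr[3] = %d (expected 4)\n", host_ptr[3]);
  cudaFreeHost(host_ptr);
  return 0;
}
```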
5 changes: 5 additions & 0 deletions csrc/ops.h
@@ -115,6 +115,11 @@ void advance_step_flashinfer(
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);

void copy_subranges(torch::Tensor& matrix_src, torch::Tensor& matrix_diff,
torch::Tensor& matrix_tgt, int64_t n);

torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);

#ifndef USE_ROCM
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
const torch::Tensor& codebooks,
75 changes: 75 additions & 0 deletions csrc/prepare_inputs/copy_subranges.cu
@@ -0,0 +1,75 @@
#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

namespace vllm {
__global__ void copy_subranges_kernel(const int* __restrict__ matrix_src,
const int* __restrict__ matrix_diff,
int* __restrict__ matrix_tgt, int64_t M) {
int row_id = blockIdx.x;
int row_offset = row_id * M;

int start = matrix_diff[row_id * 2];
int length = matrix_diff[row_id * 2 + 1];
int end = start + length;
int thread_idx = threadIdx.x;
for (int i = start + thread_idx; i < end; i += blockDim.x) {
[Review comment, Member]: most threads in the block would be idle, e.g. for decoding, there's only one or even no entry that changes in the block table.
int idx = row_offset + i;
[Review comment, Collaborator]: Should row_offset and idx be int64_t? I.e. could they overflow an int32? (See the 64-bit indexing sketch after this file.)

matrix_tgt[idx] = matrix_src[idx];
}
}
} // namespace vllm

void copy_subranges(torch::Tensor& matrix_src, torch::Tensor& matrix_diff,
torch::Tensor& matrix_tgt, int64_t n) {
// NOTE(woosuk): Here, we skip most of the error checking to minimize the
// CPU overheads. We assume that the caller will pass the correct inputs.

// Check tensor properties
// TORCH_CHECK(matrix_src.is_cuda(), "matrix_src must be a CUDA tensor");
// TORCH_CHECK(matrix_diff.is_cuda(), "matrix_diff must be a CUDA tensor");
// TORCH_CHECK(matrix_tgt.is_cuda(), "matrix_tgt must be a CUDA tensor");
// TORCH_CHECK(matrix_src.is_contiguous(), "matrix_src must be contiguous");
// TORCH_CHECK(matrix_diff.is_contiguous(), "matrix_diff must be contiguous");
// TORCH_CHECK(matrix_tgt.is_contiguous(), "matrix_tgt must be contiguous");

auto src_sizes = matrix_src.sizes();
auto diff_sizes = matrix_diff.sizes();
auto tgt_sizes = matrix_tgt.sizes();

// TORCH_CHECK(src_sizes.size() == 2, "matrix_src must be 2D");
// TORCH_CHECK(diff_sizes.size() == 2, "matrix_diff must be 2D");
// TORCH_CHECK(tgt_sizes.size() == 2, "matrix_tgt must be 2D");

int64_t N = src_sizes[0];
int64_t M = src_sizes[1];

// TORCH_CHECK(diff_sizes[0] == N, "matrix_diff first dim must match N");
// TORCH_CHECK(diff_sizes[1] == 2, "matrix_diff second dim must be 2");
// TORCH_CHECK(tgt_sizes[0] == N && tgt_sizes[1] == M,
// "matrix_tgt must have same shape as matrix_src");

// TORCH_CHECK(n <= N, "n must be <= N");

const int* d_matrix_src = matrix_src.data_ptr<int>();
const int* d_matrix_diff = matrix_diff.data_ptr<int>();
int* d_matrix_tgt = matrix_tgt.data_ptr<int>();

// One thread block per row.
int blocks = n;
[Review comment, Member]: it seems this can easily oversubscribe GPU SMs. (See the grid-stride launch sketch after this file.)

int threads;
if (blocks < 128) {
threads = 1024;
} else if (blocks < 256) {
threads = 512;
} else if (blocks < 512) {
threads = 256;
} else {
threads = 128;
}
const at::cuda::OptionalCUDAGuard device_guard(device_of(matrix_tgt));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
vllm::copy_subranges_kernel<<<blocks, threads, 0, stream>>>(
d_matrix_src, d_matrix_diff, d_matrix_tgt, M);
}
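
In response to the inline review comments above (possible int32 overflow of row_offset and idx, and a grid of n blocks that can oversubscribe the SMs), here is a hedged sketch of one possible follow-up, not part of this PR: index arithmetic widened to int64_t plus a grid-stride loop over rows so the block count can be capped. The kernel name, the cap of 1024 blocks, and the launch snippet are illustrative assumptions; this sketch does not address the per-block thread idleness noted for decode-only updates.

```cuda
// Illustrative variant (not the PR's implementation): 64-bit offsets avoid
// int32 overflow when N * M exceeds 2^31 - 1, and the grid-stride loop over
// rows lets the caller cap gridDim.x instead of launching one block per row.
#include <cstdint>

__global__ void copy_subranges_kernel_strided(
    const int* __restrict__ matrix_src, const int* __restrict__ matrix_diff,
    int* __restrict__ matrix_tgt, int64_t M, int64_t n) {
  for (int64_t row_id = blockIdx.x; row_id < n; row_id += gridDim.x) {
    int64_t row_offset = row_id * M;
    int start = matrix_diff[row_id * 2];
    int length = matrix_diff[row_id * 2 + 1];
    int end = start + length;
    for (int i = start + static_cast<int>(threadIdx.x); i < end;
         i += blockDim.x) {
      int64_t idx = row_offset + i;  // 64-bit index, cannot wrap here
      matrix_tgt[idx] = matrix_src[idx];
    }
  }
}

// Hypothetical launch: bound the grid instead of one block per row,
// e.g. blocks = min(n, 1024), threads = 128:
//   copy_subranges_kernel_strided<<<blocks, threads, 0, stream>>>(
//       d_matrix_src, d_matrix_diff, d_matrix_tgt, M, n);
```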
9 changes: 9 additions & 0 deletions csrc/torch_bindings.cpp
@@ -21,6 +21,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);

ops.def("get_cuda_view_from_cpu_tensor(Tensor cpu_tensor) -> Tensor");
ops.impl("get_cuda_view_from_cpu_tensor", torch::kCPU,
&get_cuda_view_from_cpu_tensor);

// Attention ops
// Compute the attention between an input query and the cached
// keys/values using PagedAttention.
@@ -98,6 +102,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
") -> ()");
ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);

ops.def(
"copy_subranges(Tensor matrix_src, Tensor matrix_diff, Tensor! "
"matrix_tgt, int n) -> ()");
ops.impl("copy_subranges", torch::kCUDA, &copy_subranges);

// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
47 changes: 47 additions & 0 deletions tests/kernels/test_copy_subranges.py
@@ -0,0 +1,47 @@
import random

import pytest
import torch

from vllm import _custom_ops as ops
from vllm.platforms import current_platform

SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_copy_subranges(seed, device):
torch.set_default_device(device)
current_platform.seed_everything(seed)

num_rows = 1024
num_cols = 1024
src_matrix = torch.zeros(num_rows,
num_cols,
device=device,
dtype=torch.int32)
dst_matrix = torch.zeros(num_rows,
num_cols,
device=device,
dtype=torch.int32)
diff_matrix = torch.zeros(num_rows, 2, device=device, dtype=torch.int32)

for i in range(num_rows):
start_idx = random.randint(0, num_cols - 1)
end_idx = random.randint(start_idx, num_cols - 1)
num_diffs = end_idx - start_idx

src_matrix[i, start_idx:end_idx] = torch.randint(0,
100, (num_diffs, ),
device=device,
dtype=torch.int32)

diff_matrix[i, 0] = start_idx
diff_matrix[i, 1] = num_diffs

ops.copy_subranges(src_matrix, diff_matrix, dst_matrix, num_rows)
assert torch.allclose(src_matrix, dst_matrix, rtol=0, atol=0)
60 changes: 60 additions & 0 deletions tests/kernels/test_uva.py
@@ -0,0 +1,60 @@
import pytest
import torch

from vllm.utils import get_cuda_view_from_cpu_tensor, is_uva_available

CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


@pytest.mark.skipif(not is_uva_available(), reason="UVA is not available.")
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_cpu_write(device):
torch.set_default_device(device)
cpu_tensor = torch.zeros(10,
10,
device="cpu",
pin_memory=True,
dtype=torch.int32)
cuda_view = get_cuda_view_from_cpu_tensor(cpu_tensor)
assert cuda_view.device.type == "cuda"

assert cuda_view[0, 0] == 0
assert cuda_view[2, 3] == 0
assert cuda_view[4, 5] == 0

cpu_tensor[0, 0] = 1
cpu_tensor[2, 3] = 2
cpu_tensor[4, 5] = -1

cuda_view.mul_(2)
assert cuda_view[0, 0] == 2
assert cuda_view[2, 3] == 4
assert cuda_view[4, 5] == -2


@pytest.mark.skipif(not is_uva_available(), reason="UVA is not available.")
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_gpu_write(device):
torch.set_default_device(device)
cpu_tensor = torch.zeros(10,
10,
device="cpu",
pin_memory=True,
dtype=torch.int32)
cuda_view = get_cuda_view_from_cpu_tensor(cpu_tensor)
assert cuda_view.device.type == "cuda"

assert cuda_view[0, 0] == 0
assert cuda_view[2, 3] == 0
assert cuda_view[4, 5] == 0

cuda_view[0, 0] = 1
cuda_view[2, 3] = 2
cuda_view[4, 5] = -1
cuda_view.mul_(2)

assert cpu_tensor[0, 0] == 2
assert cpu_tensor[2, 3] == 4
assert cpu_tensor[4, 5] == -2
52 changes: 52 additions & 0 deletions tests/v1/worker/test_gpu_block_table.py
@@ -0,0 +1,52 @@
import pytest
import random
import time

import torch

from vllm.v1.worker.gpu_block_table import BlockTable

MAX_NUM_REQS = 1024
MAX_MODEL_LEN = 128 * 1024
BLOCK_SIZE = 16
MAX_NUM_BLOCKS_PER_REQ = MAX_MODEL_LEN // BLOCK_SIZE


@pytest.mark.parametrize("do_wait", [False, True])
def test_block_table(do_wait: bool):
random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

block_table = BlockTable(
max_num_reqs=MAX_NUM_REQS,
max_model_len=MAX_MODEL_LEN,
max_num_blocks_per_req=MAX_NUM_BLOCKS_PER_REQ,
pin_memory=True,
device=torch.device(0),
)

num_blocks = random.randint(1, MAX_NUM_BLOCKS_PER_REQ - 1)
block_ids = torch.randint(0, MAX_NUM_BLOCKS_PER_REQ, (num_blocks,), dtype=torch.int32, device="cpu")
block_table.add_row(0, block_ids)
num_blocks = random.randint(1, MAX_NUM_BLOCKS_PER_REQ - 100)
block_ids = torch.randint(0, MAX_NUM_BLOCKS_PER_REQ, (num_blocks,), dtype=torch.int32, device="cpu")
block_table.add_row(1, block_ids)
block_table.commit(2)

torch.cuda.synchronize()
if do_wait:
time.sleep(1)

block_ids = torch.randint(0, MAX_NUM_BLOCKS_PER_REQ, (100,), dtype=torch.int32, device="cpu")
block_table.append_row(1, num_blocks, block_ids)
block_table.move_row(1, 0)
block_table.commit(2)

torch.cuda.synchronize()
if do_wait:
time.sleep(1)

torch.testing.assert_close(block_table.block_table[:1].cpu(), block_table.block_table_cpu[:1])

if __name__ == "__main__":
test_block_table(do_wait=False)
14 changes: 14 additions & 0 deletions vllm/_custom_ops.py
@@ -219,6 +219,20 @@ def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
block_table_bound)


# copy subrange op. Used for input preparation in the vLLM V1 GPU backend.
def copy_subranges(
src_matrix: torch.Tensor,
diff_matrix: torch.Tensor,
tgt_matrix: torch.Tensor,
num_subranges: int,
) -> None:
# NOTE(woosuk): We use `torch.ops._C.copy_subranges.default` instead of
# `torch.ops._C.copy_subranges` to avoid unnecessary CPU overheads from
# the dispatcher.
torch.ops._C.copy_subranges.default(src_matrix, diff_matrix, tgt_matrix,
num_subranges)


# fused quant layer norm ops
def rms_norm_dynamic_per_token_quant(
input: torch.Tensor,
16 changes: 16 additions & 0 deletions vllm/utils.py
@@ -707,6 +707,14 @@ def is_pin_memory_available() -> bool:
return current_platform.is_pin_memory_available()


@lru_cache(maxsize=None)
def is_uva_available() -> bool:
"""Check if Unified Virtual Addressing (UVA) is available."""
# UVA requires pinned memory.
# TODO(woosuk): Add more requirements for UVA.
return is_pin_memory_available()


class DeviceMemoryProfiler:

def __init__(self, device: Optional[torch.types.Device] = None):
@@ -1557,6 +1565,14 @@ def weak_ref_tensors(
raise ValueError("Invalid type for tensors")


def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor:
"""
Get a CUDA view of a CPU tensor using Unified Virtual Addressing (UVA).
"""
assert cpu_tensor.is_pinned(), "CPU tensor must be pinned"
return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor)


def is_in_doc_build() -> bool:
try:
from sphinx.ext.autodoc.mock import _MockModule