[Core] Performance optimization for swap_blocks by cuda kernels #11531

Open · wants to merge 6 commits into base: main

Changes from 5 commits
89 changes: 87 additions & 2 deletions csrc/cache_kernels.cu
@@ -15,14 +15,69 @@
#include <cassert>
#include <map>
#include <vector>
#include <cstdio>

#ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 __nv_bfloat16;
#endif

void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor& block_mapping) {
namespace vllm {

// Copy pages between src and dst: src_to_dst is a flat
// [src0, dst0, src1, dst1, ...] page mapping; each thread block handles one
// pair, and its threads stride over the page's elements.
template <typename T>
__global__ void paged_copy(T* __restrict__ dst, const T* __restrict__ src,
const int64_t* src_to_dst, const int num_pages,
const int num_elements_per_page) {
const int64_t srcPageIdx = src_to_dst[blockIdx.x << 1];
const int64_t dstPageIdx = src_to_dst[(blockIdx.x << 1) | 1];

const int64_t srcPageOffset = srcPageIdx * num_elements_per_page;
const int64_t dstPageOffset = dstPageIdx * num_elements_per_page;

for (int i = threadIdx.x; i < num_elements_per_page; i += blockDim.x) {
dst[dstPageOffset + i] = src[srcPageOffset + i];
}
}

} // namespace vllm

// Launch one thread block per block pair; each block's byte count is copied
// as DTYPE-sized elements.
template <int DTYPE_LEN, typename DTYPE>
void launch_swap_block_kernel(DTYPE* dst, const DTYPE* src,
const int64_t* block_mapping_ptr,
const int num_blocks,
const int block_size_in_bytes,
const torch::Device& device) {
c10::cuda::CUDAGuard device_guard(device);

int num_threads = 1024;
int grid_size = num_blocks;

const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
vllm::paged_copy<<<grid_size, num_threads, 0, stream>>>(
dst, src, block_mapping_ptr, num_blocks, block_size_in_bytes / DTYPE_LEN);
}

template <typename T, typename TENSOR_TYPE>
T* get_kernel_ptr(TENSOR_TYPE& tensor) {
// Get the kernel-accessible pointer of the given type T
// Returns NULL if the tensor is on CPU and non-pinned
torch::Device device = tensor.device();
if (device.is_cuda()) {
return static_cast<T*>(tensor.data_ptr());
} else if (device.is_cpu() && tensor.is_pinned()) {
T* ptr;
cudaHostGetDevicePointer((void**)&ptr,
static_cast<void*>(tensor.data_ptr()), 0);
return ptr;
} else if (device.is_cpu()) {
return NULL;
} else {
TORCH_CHECK(false, "Invalid device");
}
}

void swap_blocks_slow(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor& block_mapping) {
torch::Device src_device = src.device();
torch::Device dst_device = dst.device();
cudaMemcpyKind memcpy_type;
@@ -62,6 +117,36 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
}
}

void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor& block_mapping) {
int64_t* src_ptr = get_kernel_ptr<int64_t, torch::Tensor>(src);
int64_t* dst_ptr = get_kernel_ptr<int64_t, torch::Tensor>(dst);
const int64_t* block_mapping_ptr =
get_kernel_ptr<const int64_t, const torch::Tensor>(block_mapping);

if (src_ptr == NULL || dst_ptr == NULL || block_mapping_ptr == NULL) {
// fall back to the slow implementation
swap_blocks_slow(src, dst, block_mapping);
} else {
// Check the device
torch::Device src_device = src.device();
torch::Device dst_device = dst.device();
torch::Device block_mapping_device = block_mapping.device();
if (src_device.is_cuda() && dst_device.is_cuda()) {
TORCH_CHECK(src_device.index() == dst_device.index(),
"src and dst must be on the same GPU");
}
torch::Device cuda_device = src_device.is_cuda() ? src_device : dst_device;

const int64_t num_blocks = block_mapping.size(0);
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();

launch_swap_block_kernel<8, int64_t>(dst_ptr, (const int64_t*)src_ptr,
block_mapping_ptr, num_blocks,
block_size_in_bytes, cuda_device);
}
}

namespace vllm {

// Grid: (num_layers, num_pairs)
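
Taken together, the C++ changes above add a kernel-driven fast path to swap_blocks whenever all three tensors are directly addressable from the GPU (CUDA memory or pinned host memory). Below is a minimal usage sketch of that path, assuming the Python binding vllm._custom_ops.swap_blocks keeps the (src, dst, block_mapping) signature exercised in tests/kernels/test_cache.py; the shapes and dtype are illustrative only.

# Hypothetical example: swap two KV-cache blocks from GPU to pinned CPU memory.
# The fast CUDA path is taken because src, dst, and block_mapping are all
# kernel-accessible (CUDA memory or pinned host memory).
import torch
from vllm import _custom_ops as ops  # assumed binding for csrc/cache_kernels.cu

num_blocks, block_numel = 16, 16 * 128  # illustrative block count and size
src = torch.randn(num_blocks, block_numel, device="cuda")
dst = torch.empty(num_blocks, block_numel, device="cpu", pin_memory=True)

# Each row is a (src_block, dst_block) pair, matching the flat int64 layout
# read by the paged_copy kernel.
block_mapping = torch.tensor([[0, 3], [5, 7]],
                             dtype=torch.int64,
                             device="cpu").pin_memory()

ops.swap_blocks(src, dst, block_mapping)
torch.cuda.synchronize()  # the copy is issued on the current CUDA stream
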
2 changes: 1 addition & 1 deletion tests/kernels/test_cache.py
@@ -362,7 +362,7 @@ def test_swap_blocks(
block_mapping = list(zip(src_blocks, dst_blocks))
block_mapping_tensor = torch.tensor(block_mapping,
dtype=torch.int64,
device="cpu").view(-1, 2)
device="cpu").view(-1, 2).pin_memory()

# Create the KV caches on the first device.
src_key_caches, src_value_caches = kv_cache_factory(
26 changes: 26 additions & 0 deletions vllm/worker/worker.py
@@ -95,6 +95,11 @@ def __init__(
self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {}

# Uninitialized buffers for swapping blocks in/out. They will be initialized
# by initialize_cache.
self.blocks_to_swap_in_buffer: torch.Tensor
self.blocks_to_swap_out_buffer: torch.Tensor

# Torch profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
if envs.VLLM_TORCH_PROFILER_DIR:
@@ -274,6 +279,17 @@ def initialize_cache(self, num_gpu_blocks: int,
self._init_cache_engine()
self._warm_up_model()

# Initialize the buffers for swapping blocks in/out.
max_num_blocks = max(num_gpu_blocks, num_cpu_blocks)
self.blocks_to_swap_in_buffer = torch.zeros((max_num_blocks, 2),
dtype=torch.int64,
device="cpu",
pin_memory=True)
self.blocks_to_swap_out_buffer = torch.zeros((max_num_blocks, 2),
dtype=torch.int64,
device="cpu",
pin_memory=True)
Review comment (Member): Some systems do not have pin memory (notably, WSL). We need to take care of that; otherwise this PR LGTM.

Review comment (Member): WSL does not support UVA, either. You can use is_pin_memory_available to determine if this optimization can be used.

Reply (Contributor, author): Sounds good, just pushed a new commit fixing this.
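
A minimal sketch of the guard the reviewers suggest, assuming vllm.utils.is_pin_memory_available() is the helper referenced above and that, when pinning is unavailable, the worker simply allocates pageable CPU buffers so swap_blocks falls back to its slow memcpy path. The wrapper below is hypothetical and not the change the author actually pushed.

# Hypothetical guard for Worker.initialize_cache: only request pinned buffers
# when the platform supports them (WSL, for example, does not).
import torch
from vllm.utils import is_pin_memory_available  # helper suggested in the review

def _allocate_swap_buffer(max_num_blocks: int) -> torch.Tensor:
    pin_memory = is_pin_memory_available()
    # Pinned memory lets the CUDA kernel read the (src, dst) block pairs over
    # UVA; without it, swap_blocks takes the slow, per-block copy path.
    return torch.zeros((max_num_blocks, 2),
                       dtype=torch.int64,
                       device="cpu",
                       pin_memory=pin_memory)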


def _init_cache_engine(self):
assert self.cache_config.num_gpu_blocks is not None
self.cache_engine = [
@@ -315,6 +331,16 @@ def prepare_worker_input(
blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
device="cpu",
dtype=torch.int64).view(-1, 2)
swap_in_cnt = blocks_to_swap_in.size(0)
swap_out_cnt = blocks_to_swap_out.size(0)

# The buffers will be allocated only if the cache engines are initialized
if hasattr(self, "blocks_to_swap_in_buffer"):
self.blocks_to_swap_in_buffer[:swap_in_cnt] = blocks_to_swap_in
self.blocks_to_swap_out_buffer[:swap_out_cnt] = blocks_to_swap_out
blocks_to_swap_in = self.blocks_to_swap_in_buffer[:swap_in_cnt]
blocks_to_swap_out = self.blocks_to_swap_out_buffer[:swap_out_cnt]

# `blocks_to_copy` is a gpu tensor. The src and tgt of
# blocks to copy are in the same device, and `blocks_to_copy`
# can be used directly within cuda kernels.