diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 8a95279f9a25a..06fc18f076826 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -15,14 +15,69 @@
 #include 
 #include 
 #include 
+#include 
 
 #ifdef USE_ROCM
   #include 
 typedef __hip_bfloat16 __nv_bfloat16;
 #endif
 
-void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
-                 const torch::Tensor& block_mapping) {
+namespace vllm {
+
+template <typename T>
+__global__ void paged_copy(T* __restrict__ dst, const T* __restrict__ src,
+                           const int64_t* src_to_dst, const int num_pages,
+                           const int num_elements_per_page) {
+  const int64_t srcPageIdx = src_to_dst[blockIdx.x << 1];
+  const int64_t dstPageIdx = src_to_dst[(blockIdx.x << 1) | 1];
+
+  const int64_t srcPageOffset = srcPageIdx * num_elements_per_page;
+  const int64_t dstPageOffset = dstPageIdx * num_elements_per_page;
+
+  for (int i = threadIdx.x; i < num_elements_per_page; i += blockDim.x) {
+    dst[dstPageOffset + i] = src[srcPageOffset + i];
+  }
+}
+
+}  // namespace vllm
+
+template <int DTYPE_LEN, typename DTYPE>
+void launch_swap_block_kernel(DTYPE* dst, const DTYPE* src,
+                              const int64_t* block_mapping_ptr,
+                              const int num_blocks,
+                              const int block_size_in_bytes,
+                              const torch::Device& device) {
+  c10::cuda::CUDAGuard device_guard(device);
+
+  int num_threads = 1024;
+  int grid_size = num_blocks;
+
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  vllm::paged_copy<<<grid_size, num_threads, 0, stream>>>(
+      dst, src, block_mapping_ptr, num_blocks, block_size_in_bytes / DTYPE_LEN);
+}
+
+template <typename T, typename TENSOR_TYPE>
+T* get_kernel_ptr(TENSOR_TYPE& tensor) {
+  // Get the kernel-accessible pointer of the given type T
+  // Returns NULL if the tensor is on CPU and non-pinned
+  torch::Device device = tensor.device();
+  if (device.is_cuda()) {
+    return static_cast<T*>(tensor.data_ptr());
+  } else if (device.is_cpu() && tensor.is_pinned()) {
+    T* ptr;
+    cudaHostGetDevicePointer((void**)&ptr,
+                             static_cast<void*>(tensor.data_ptr()), 0);
+    return ptr;
+  } else if (device.is_cpu()) {
+    return NULL;
+  } else {
+    TORCH_CHECK(false, "Invalid device");
+  }
+}
+
+void swap_blocks_slow(torch::Tensor& src, torch::Tensor& dst,
+                      const torch::Tensor& block_mapping) {
   torch::Device src_device = src.device();
   torch::Device dst_device = dst.device();
   cudaMemcpyKind memcpy_type;
@@ -62,6 +117,36 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
   }
 }
 
+void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
+                 const torch::Tensor& block_mapping) {
+  int64_t* src_ptr = get_kernel_ptr<int64_t, torch::Tensor>(src);
+  int64_t* dst_ptr = get_kernel_ptr<int64_t, torch::Tensor>(dst);
+  const int64_t* block_mapping_ptr =
+      get_kernel_ptr<const int64_t, const torch::Tensor>(block_mapping);
+
+  if (src_ptr == NULL || dst_ptr == NULL || block_mapping_ptr == NULL) {
+    // Fall back to the slow implementation
+    swap_blocks_slow(src, dst, block_mapping);
+  } else {
+    // Check the device
+    torch::Device src_device = src.device();
+    torch::Device dst_device = dst.device();
+    torch::Device block_mapping_device = block_mapping.device();
+    if (src_device.is_cuda() && dst_device.is_cuda()) {
+      TORCH_CHECK(src_device.index() == dst_device.index(),
+                  "src and dst must be on the same GPU");
+    }
+    torch::Device cuda_device = src_device.is_cuda() ? src_device : dst_device;
+
+    const int64_t num_blocks = block_mapping.size(0);
+    const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
+
+    launch_swap_block_kernel<8, int64_t>(dst_ptr, (const int64_t*)src_ptr,
+                                         block_mapping_ptr, num_blocks,
+                                         block_size_in_bytes, cuda_device);
+  }
+}
+
 namespace vllm {
 
 // Grid: (num_layers, num_pairs)
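Note (illustrative sketch, not part of this diff): the fast path above hinges on get_kernel_ptr returning a device-visible alias of a pinned CPU tensor via cudaHostGetDevicePointer, so the kernel can read the block mapping and the CPU-side cache blocks directly over the interconnect instead of issuing an explicit cudaMemcpy per block. A minimal standalone CUDA program showing that zero-copy pattern follows; it assumes a platform with unified virtual addressing, and the kernel name read_mapping and the buffer size are invented for the example.

// Illustrative sketch (not part of this PR): zero-copy read of a mapped
// pinned host buffer, the mechanism behind get_kernel_ptr's pinned-CPU branch.
#include <cstdint>
#include <cuda_runtime.h>

__global__ void read_mapping(const int64_t* mapping, int64_t* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  // The kernel dereferences host memory directly through the device alias.
  if (i < n) out[i] = mapping[i];
}

int main() {
  const int n = 8;
  int64_t *h_mapping, *d_alias, *d_out;

  // cudaHostAllocMapped pins the allocation and maps it into device space.
  cudaHostAlloc((void**)&h_mapping, n * sizeof(int64_t), cudaHostAllocMapped);
  for (int i = 0; i < n; ++i) h_mapping[i] = i;

  // Device-side alias of the pinned host buffer; with unified virtual
  // addressing this is what get_kernel_ptr returns for a pinned CPU tensor.
  cudaHostGetDevicePointer((void**)&d_alias, h_mapping, 0);
  cudaMalloc((void**)&d_out, n * sizeof(int64_t));

  read_mapping<<<1, 32>>>(d_alias, d_out, n);
  cudaDeviceSynchronize();

  cudaFree(d_out);
  cudaFreeHost(h_mapping);
  return 0;
}

A plain pageable CPU tensor has no such device alias, which is exactly the case where get_kernel_ptr returns NULL and swap_blocks falls back to swap_blocks_slow.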
diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py
index 40550ed51e2c7..04a1057838ea6 100644
--- a/tests/kernels/test_cache.py
+++ b/tests/kernels/test_cache.py
@@ -362,7 +362,7 @@ def test_swap_blocks(
     block_mapping = list(zip(src_blocks, dst_blocks))
     block_mapping_tensor = torch.tensor(block_mapping,
                                         dtype=torch.int64,
-                                        device="cpu").view(-1, 2)
+                                        device="cpu").view(-1, 2).pin_memory()
 
     # Create the KV caches on the first device.
     src_key_caches, src_value_caches = kv_cache_factory(
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index f51b51d433d3d..8c2f3775b5dae 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -21,7 +21,7 @@
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
                            SequenceGroupMetadata, SequenceGroupMetadataDelta)
-from vllm.utils import GiB_bytes, memory_profiling
+from vllm.utils import GiB_bytes, is_pin_memory_available, memory_profiling
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
 from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
@@ -95,6 +95,11 @@ def __init__(
         self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
         self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {}
 
+        # Uninitialized buffers for swapping blocks in/out. They will be
+        # initialized by initialize_cache.
+        self.blocks_to_swap_in_buffer: torch.Tensor
+        self.blocks_to_swap_out_buffer: torch.Tensor
+
         # Torch profiler. Enabled and configured through env vars:
         # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
         if envs.VLLM_TORCH_PROFILER_DIR:
@@ -274,6 +279,18 @@ def initialize_cache(self, num_gpu_blocks: int,
         self._init_cache_engine()
         self._warm_up_model()
 
+        # Initialize the buffers for swapping blocks in/out.
+        max_num_blocks = max(num_gpu_blocks, num_cpu_blocks)
+        use_pin_memory = is_pin_memory_available()
+        self.blocks_to_swap_in_buffer = torch.zeros((max_num_blocks, 2),
+                                                    dtype=torch.int64,
+                                                    device="cpu",
+                                                    pin_memory=use_pin_memory)
+        self.blocks_to_swap_out_buffer = torch.zeros((max_num_blocks, 2),
+                                                     dtype=torch.int64,
+                                                     device="cpu",
+                                                     pin_memory=use_pin_memory)
+
     def _init_cache_engine(self):
         assert self.cache_config.num_gpu_blocks is not None
         self.cache_engine = [
@@ -315,6 +332,16 @@ def prepare_worker_input(
         blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
                                           device="cpu",
                                           dtype=torch.int64).view(-1, 2)
+        swap_in_cnt = blocks_to_swap_in.size(0)
+        swap_out_cnt = blocks_to_swap_out.size(0)
+
+        # The buffers are allocated only if the cache engine is initialized.
+        if hasattr(self, "blocks_to_swap_in_buffer"):
+            self.blocks_to_swap_in_buffer[:swap_in_cnt] = blocks_to_swap_in
+            self.blocks_to_swap_out_buffer[:swap_out_cnt] = blocks_to_swap_out
+            blocks_to_swap_in = self.blocks_to_swap_in_buffer[:swap_in_cnt]
+            blocks_to_swap_out = self.blocks_to_swap_out_buffer[:swap_out_cnt]
+
         # `blocks_to_copy` is a gpu tensor. The src and tgt of
         # blocks to copy are in the same device, and `blocks_to_copy`
         # can be used directly within cuda kernels.
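Note (illustrative sketch, not part of this diff): two implicit assumptions in the fast path are worth keeping in mind when reviewing. First, launch_swap_block_kernel<8, int64_t> copies each block in 8-byte granules, and the integer division block_size_in_bytes / DTYPE_LEN would silently drop a tail if a block's byte size were not a multiple of 8. Second, the worker now rewrites only the prefix of pinned buffers allocated once in initialize_cache, rather than pinning a fresh mapping tensor on every step. The toy CUDA program below restates both ideas end to end; the name paged_copy_i64 and all sizes are invented for the example.

// Illustrative sketch (not part of this PR): a mapped pinned mapping buffer
// allocated once and reused per step, plus a paged copy in int64_t (8-byte)
// granules, mirroring launch_swap_block_kernel<8, int64_t>.
#include <cassert>
#include <cstdint>
#include <cstring>
#include <cuda_runtime.h>

__global__ void paged_copy_i64(int64_t* dst, const int64_t* src,
                               const int64_t* mapping, int elems_per_block) {
  // mapping is a flattened [num_pairs, 2] array of (src block, dst block).
  const int64_t src_off = mapping[blockIdx.x * 2] * (int64_t)elems_per_block;
  const int64_t dst_off = mapping[blockIdx.x * 2 + 1] * (int64_t)elems_per_block;
  for (int i = threadIdx.x; i < elems_per_block; i += blockDim.x)
    dst[dst_off + i] = src[src_off + i];
}

int main() {
  const int num_blocks = 16;
  const size_t block_bytes = 4096;  // byte size of one cache block
  assert(block_bytes % sizeof(int64_t) == 0);  // required by the 8-byte granule
  const int elems_per_block = (int)(block_bytes / sizeof(int64_t));

  int64_t *src, *dst, *mapping, *mapping_dev;
  cudaMalloc((void**)&src, num_blocks * block_bytes);
  cudaMalloc((void**)&dst, num_blocks * block_bytes);

  // Allocated once, like blocks_to_swap_in_buffer / blocks_to_swap_out_buffer;
  // only its prefix is rewritten on each scheduling step.
  cudaHostAlloc((void**)&mapping, num_blocks * 2 * sizeof(int64_t),
                cudaHostAllocMapped);
  cudaHostGetDevicePointer((void**)&mapping_dev, mapping, 0);

  // One "step": swap two block pairs, reusing the buffer's prefix.
  const int num_pairs = 2;
  const int64_t pairs[4] = {0, 3, 1, 2};  // (src, dst), (src, dst)
  std::memcpy(mapping, pairs, sizeof(pairs));

  paged_copy_i64<<<num_pairs, 1024>>>(dst, src, mapping_dev, elems_per_block);
  cudaDeviceSynchronize();

  cudaFreeHost(mapping);
  cudaFree(src);
  cudaFree(dst);
  return 0;
}

In the actual change the per-step refill happens in prepare_worker_input, and the pinned buffers are sized max(num_gpu_blocks, num_cpu_blocks) rows, so any legal swap mapping fits in the preallocated prefix.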