diff --git a/rllm/kernels/cache.cu b/rllm/kernels/cache.cu
deleted file mode 100644
index 004e210f..00000000
--- a/rllm/kernels/cache.cu
+++ /dev/null
@@ -1,266 +0,0 @@
-// based on https://github.com/vllm-project/vllm/blob/b9fe4616f98b77b4b9458bce203aa6544cb31ef2/csrc/cache_kernels.cu
-
-
-#include <cuda_bf16.h>
-#include <cassert>
-#include <algorithm>
-
-// void swap_blocks(...) -> implement in Rust
-
-// Grid: (num_layers, num_pairs)
-template <typename scalar_t>
-__global__ void copy_blocks_kernel(
-  int64_t* key_cache_ptrs,
-  int64_t* value_cache_ptrs,
-  const int* __restrict__ block_mapping,
-  const int numel_per_block) {
-  const int layer_idx = blockIdx.x;
-  const int pair_idx = blockIdx.y;
-
-  scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
-  scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
-  int src_block_number = block_mapping[2 * pair_idx];
-  int dst_block_number = block_mapping[2 * pair_idx + 1];
-
-  const int src_block_offset = src_block_number * numel_per_block;
-  const int dst_block_offset = dst_block_number * numel_per_block;
-  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
-    int src_offset = src_block_offset + i;
-    int dst_offset = dst_block_offset + i;
-    key_cache[dst_offset] = key_cache[src_offset];
-  }
-  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
-    int src_offset = src_block_offset + i;
-    int dst_offset = dst_block_offset + i;
-    value_cache[dst_offset] = value_cache[src_offset];
-  }
-}
-
-extern "C" void copy_blocks_bf16(
-  int64_t* key_cache_ptrs,
-  int64_t* value_cache_ptrs,
-  const int* __restrict__ block_mapping,
-  const int numel_per_block,
-  const int num_pairs,
-  const int num_layers)
-{
-  dim3 grid(num_layers, num_pairs);
-  dim3 block(std::min(1024, numel_per_block));
-  const cudaStream_t stream = 0;
-  copy_blocks_kernel<__nv_bfloat16><<<grid, block, 0, stream>>>(
-    key_cache_ptrs,
-    value_cache_ptrs,
-    block_mapping,
-    numel_per_block);
-}
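Note: the `swap_blocks(...)` routine that the comment near the top of this file defers to the Rust side boils down to a per-block memcpy between two cache buffers (e.g. GPU <-> CPU swap space). A minimal host-side sketch follows, written in C for consistency with the kernels here rather than as the project's actual Rust code; the names `swap_blocks_sketch` and `block_size_in_bytes` are illustrative:

    // Copies whole KV-cache blocks between two buffers. block_mapping holds
    // (src_block_number, dst_block_number) pairs, as in copy_blocks above.
    // Illustrative sketch only -- not the deleted code or its Rust port.
    #include <cuda_runtime.h>
    #include <stdint.h>

    void swap_blocks_sketch(const void* src_cache, void* dst_cache,
                            const int64_t* block_mapping, int num_pairs,
                            size_t block_size_in_bytes, cudaMemcpyKind kind,
                            cudaStream_t stream) {
      for (int i = 0; i < num_pairs; i++) {
        const char* src = (const char*)src_cache
                        + block_mapping[2 * i] * block_size_in_bytes;
        char* dst = (char*)dst_cache
                  + block_mapping[2 * i + 1] * block_size_in_bytes;
        cudaMemcpyAsync(dst, src, block_size_in_bytes, kind, stream);
      }
    }

Unlike copy_blocks_kernel, which shuffles blocks within device buffers from a single launch, a swap crosses the host/device boundary, so a plain cudaMemcpyAsync loop (or its Rust FFI equivalent) is the natural shape for it.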
-
-
-template <typename scalar_t>
-__global__ void reshape_and_cache_kernel(
-  const scalar_t* __restrict__ key,          // [num_tokens, num_heads, head_size]
-  const scalar_t* __restrict__ value,        // [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ key_cache,          // [num_blocks, num_heads, head_size/x, block_size, x]
-  scalar_t* __restrict__ value_cache,        // [num_blocks, num_heads, head_size, block_size]
-  const int* __restrict__ slot_mapping,      // [num_tokens]
-  const int key_stride,
-  const int value_stride,
-  const int num_heads,
-  const int head_size,
-  const int block_size,
-  const int x) {
-  const int token_idx = blockIdx.x;
-  const int slot_idx = slot_mapping[token_idx];
-  const int block_idx = slot_idx / block_size;
-  const int block_offset = slot_idx % block_size;
-
-  const int n = num_heads * head_size;
-  for (int i = threadIdx.x; i < n; i += blockDim.x) {
-    const int src_key_idx = token_idx * key_stride + i;
-    const int src_value_idx = token_idx * value_stride + i;
-
-    const int head_idx = i / head_size;
-    const int head_offset = i % head_size;
-    const int x_idx = head_offset / x;
-    const int x_offset = head_offset % x;
-
-    const int tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
-                          + head_idx * (head_size / x) * block_size * x
-                          + x_idx * block_size * x
-                          + block_offset * x
-                          + x_offset;
-    const int tgt_value_idx = block_idx * num_heads * head_size * block_size
-                            + head_idx * head_size * block_size
-                            + head_offset * block_size
-                            + block_offset;
-    key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]);
-    value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]);
-  }
-}
-
-
-// Grid: (num_blocks, block_size).
-template <typename scalar_t>
-__global__ void gather_cached_kv_kernel(
-  scalar_t* __restrict__ key,                // [num_tokens, [stride], num_heads, head_size]
-  scalar_t* __restrict__ value,              // [num_tokens, [stride], num_heads, head_size]
-  const scalar_t* __restrict__ key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
-  const scalar_t* __restrict__ value_cache,  // [num_blocks, num_heads, head_size, block_size]
-  const int* __restrict__ slot_mapping,      // [num_tokens]
-  const int key_stride,
-  const int value_stride,
-  const int num_heads,
-  const int head_size,
-  const int block_size,
-  const int x) {
-  const int token_idx = blockIdx.x;
-  const int slot_idx = slot_mapping[token_idx];
-  const int block_idx = slot_idx / block_size;
-  const int block_offset = slot_idx % block_size;
-
-  const int num_tokens = num_heads * head_size;
-  for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) {
-    const int tgt_key_idx = token_idx * key_stride + i;
-    const int tgt_value_idx = token_idx * value_stride + i;
-
-    const int head_idx = i / head_size;
-    const int head_offset = i % head_size;
-    const int x_idx = head_offset / x;  // the offset of the [head_size/x] dimension
-    const int x_offset = head_offset % x;
-
-    const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
-                          + head_idx * (head_size / x) * block_size * x
-                          + x_idx * block_size * x
-                          + block_offset * x
-                          + x_offset;
-    const int src_value_idx = block_idx * num_heads * head_size * block_size
-                            + head_idx * head_size * block_size
-                            + head_offset * block_size
-                            + block_offset;
-
-    key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]);
-    value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]);
-  }
}
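Both gather kernels in this file (the plain one above and the unrolled one below) decode the same packed key-cache layout [num_blocks, num_heads, head_size/x, block_size, x]: the head dimension is split into head_size/x chunks of x contiguous elements, so each chunk for one token forms a small vector. A host-side reference of that index arithmetic, mirroring the kernel code exactly (names are illustrative):

    // Reference (host) version of the key-cache index math used by the
    // kernels in this file. x is chosen so one chunk of x elements is a
    // 16-byte vector (vLLM uses x = 16 / sizeof(scalar_t)).
    static inline int key_cache_index(int block_idx, int head_idx,
                                      int head_offset, int block_offset,
                                      int num_heads, int head_size,
                                      int block_size, int x) {
      const int x_idx = head_offset / x;
      const int x_offset = head_offset % x;
      return block_idx * num_heads * (head_size / x) * block_size * x
           + head_idx * (head_size / x) * block_size * x
           + x_idx * block_size * x
           + block_offset * x
           + x_offset;
    }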
-
-template <typename scalar_t>
-__global__ void gather_cached_kv_kernel_optimized(
-  scalar_t *__restrict__ key,                // [num_tokens, [stride], num_heads, head_size]
-  scalar_t *__restrict__ value,              // [num_tokens, [stride], num_heads, head_size]
-  const scalar_t *__restrict__ key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
-  const scalar_t *__restrict__ value_cache,  // [num_blocks, num_heads, head_size, block_size]
-  const int *__restrict__ slot_mapping,      // [num_tokens]
-  const int key_stride,
-  const int value_stride,
-  const int num_heads,
-  const int head_size,
-  const int block_size,
-  const int x)
-{
-  const int token_idx = blockIdx.x;
-  const int slot_idx = slot_mapping[token_idx];
-  const int block_idx = slot_idx / block_size;
-  const int block_offset = slot_idx % block_size;
-
-  const int dim = num_heads * head_size;
-  assert(dim % 4 == 0);  // this is true for known use cases
-  const int unroll_factor = 4;
-  const int unrolled_dim = dim / unroll_factor;
-
-  for (int i = threadIdx.x; i < unrolled_dim; i += blockDim.x)
-  {
-    int tgt_key_indices[unroll_factor];
-    int tgt_value_indices[unroll_factor];
-    int src_key_indices[unroll_factor];
-    int src_value_indices[unroll_factor];
-    scalar_t keys_to_store[unroll_factor];
-    scalar_t values_to_store[unroll_factor];
-
-    #pragma unroll
-    for (int j = 0; j < unroll_factor; ++j)
-    {
-      int index = i + j * unrolled_dim;
-
-      const int tgt_key_idx = token_idx * key_stride + index;
-      const int tgt_value_idx = token_idx * value_stride + index;
-
-      const int head_idx = index / head_size;
-      const int head_offset = index % head_size;
-      const int x_idx = head_offset / x;
-      const int x_offset = head_offset % x;
-
-      const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
-                            + head_idx * (head_size / x) * block_size * x
-                            + x_idx * block_size * x
-                            + block_offset * x
-                            + x_offset;
-      const int src_value_idx = block_idx * num_heads * head_size * block_size
-                              + head_idx * head_size * block_size
-                              + head_offset * block_size
-                              + block_offset;
-
-      tgt_key_indices[j] = tgt_key_idx;
-      tgt_value_indices[j] = tgt_value_idx;
-      src_key_indices[j] = src_key_idx;
-      src_value_indices[j] = src_value_idx;
-
-      keys_to_store[j] = __ldg(&key_cache[src_key_idx]);
-      values_to_store[j] = __ldg(&value_cache[src_value_idx]);
-    }
-
-    #pragma unroll
-    for (int j = 0; j < unroll_factor; ++j)
-    {
-      key[tgt_key_indices[j]] = keys_to_store[j];
-      value[tgt_value_indices[j]] = values_to_store[j];
-    }
-  }
-}
-
-
-extern "C"
-void gather_scatter_inner_bf16(
-  __nv_bfloat16* __restrict__ key,          // [num_tokens, num_heads, head_size]
-  __nv_bfloat16* __restrict__ value,        // [num_tokens, num_heads, head_size]
-  __nv_bfloat16* __restrict__ key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
-  __nv_bfloat16* __restrict__ value_cache,  // [num_blocks, num_heads, head_size, block_size]
-  const int* __restrict__ slot_mapping,     // [num_tokens]
-  const int key_stride,
-  const int value_stride,
-  const int num_heads,
-  const int head_size,
-  const int block_size,
-  const int x,
-  const int num_tokens,
-  const int op)
-{
-  dim3 grid(num_tokens);
-  dim3 block(std::min(num_heads * head_size, 512));
-  const cudaStream_t stream = 0;
-  if (op == 0)
-    reshape_and_cache_kernel<__nv_bfloat16><<<grid, block, 0, stream>>>(
-      key,
-      value,
-      key_cache,
-      value_cache,
-      slot_mapping,
-      key_stride,
-      value_stride,
-      num_heads,
-      head_size,
-      block_size,
-      x);
-  else
-    gather_cached_kv_kernel_optimized<__nv_bfloat16><<<grid, block, 0, stream>>>(
-      key,
-      value,
-      key_cache,
-      value_cache,
-      slot_mapping,
-      key_stride,
-      value_stride,
-      num_heads,
-      head_size,
-      block_size,
-      x);
-}
-
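For context, a plausible call into the deleted entry point: op == 0 scatters freshly computed K/V for num_tokens tokens into the paged cache (reshape_and_cache), any other value gathers them back out. The model dimensions below are hypothetical, and x = 16 / sizeof(scalar_t) follows the vLLM layout the file was based on:

    #include <cuda_bf16.h>

    extern "C" void gather_scatter_inner_bf16(
        __nv_bfloat16* key, __nv_bfloat16* value,
        __nv_bfloat16* key_cache, __nv_bfloat16* value_cache,
        const int* slot_mapping, int key_stride, int value_stride,
        int num_heads, int head_size, int block_size, int x,
        int num_tokens, int op);

    // Hypothetical call site: scatter new K/V into the cache after the
    // attention projections. Device pointers and slot_mapping would come
    // from the Rust side; the sizes here are example values only.
    void cache_new_tokens(__nv_bfloat16* key, __nv_bfloat16* value,
                          __nv_bfloat16* key_cache, __nv_bfloat16* value_cache,
                          const int* slot_mapping, int num_tokens) {
      const int num_heads = 32, head_size = 128, block_size = 16;
      const int x = 16 / sizeof(__nv_bfloat16);  // 8 bf16 elements per 16 bytes
      const int stride = num_heads * head_size;  // K/V stored contiguously
      gather_scatter_inner_bf16(key, value, key_cache, value_cache,
                                slot_mapping, stride, stride,
                                num_heads, head_size, block_size, x,
                                num_tokens, /*op=*/0);
    }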
diff --git a/rllm/kernels/pos_encoding.cu b/rllm/kernels/pos_encoding.cu
deleted file mode 100644
index 062035a0..00000000
--- a/rllm/kernels/pos_encoding.cu
+++ /dev/null
@@ -1,110 +0,0 @@
-// adapted from https://github.com/vllm-project/vllm/blob/b9fe4616f98b77b4b9458bce203aa6544cb31ef2/csrc/pos_encoding_kernels.cu
-
-#include <cuda_bf16.h>
-#include <cassert>
-#include <algorithm>
-
-
-template <typename scalar_t, bool IS_NEOX>
-inline __device__ void apply_rotary_embedding(
-  scalar_t* __restrict__ arr,
-  const scalar_t* __restrict__ cos_ptr,
-  const scalar_t* __restrict__ sin_ptr,
-  int rot_offset,
-  int embed_dim)
-{
-  int x_index, y_index;
-  scalar_t cos, sin;
-  if (IS_NEOX) {
-    // GPT-NeoX style rotary embedding.
-    x_index = rot_offset;
-    y_index = embed_dim + rot_offset;
-    cos = __ldg(cos_ptr + x_index);
-    sin = __ldg(sin_ptr + x_index);
-  } else {
-    // GPT-J style rotary embedding.
-    x_index = 2 * rot_offset;
-    y_index = 2 * rot_offset + 1;
-    cos = __ldg(cos_ptr + x_index / 2);
-    sin = __ldg(sin_ptr + x_index / 2);
-  }
-
-  const scalar_t x = arr[x_index];
-  const scalar_t y = arr[y_index];
-  arr[x_index] = x * cos - y * sin;
-  arr[y_index] = y * cos + x * sin;
-}
-
-template <typename scalar_t>
-__global__ void rotary_embedding_kernel(
-  const int64_t* __restrict__ positions,       // [num_tokens]
-  scalar_t* __restrict__ query,                // [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ key,                  // [num_tokens, num_kv_heads, head_size]
-  const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim // 2]
-  const int rot_dim,
-  const int query_stride,
-  const int key_stride,
-  const int num_heads,
-  const int num_kv_heads,
-  const int head_size) {
-  // Each thread block is responsible for one token.
-  const int token_idx = blockIdx.x;
-  int64_t pos = positions[token_idx];
-  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
-
-  const int embed_dim = rot_dim / 2;
-  const scalar_t* cos_ptr = cache_ptr;
-  const scalar_t* sin_ptr = cache_ptr + embed_dim;
-
-  const int nq = num_heads * embed_dim;
-  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
-    const int head_idx = i / embed_dim;
-    const int token_head = token_idx * query_stride + head_idx * head_size;
-    const int rot_offset = i % embed_dim;
-    if (token_idx == 0 &&
-        token_head + rot_offset + embed_dim >= query_stride) {
-      // token_head: 3968, rot_offset: 63, embed_dim: 128, query_stride: 4096
-      printf("i: %d, head_idx: %d, head_size: %d, token_head: %d, rot_offset: %d, embed_dim: %d, query_stride: %d\n",
-             i, head_idx, head_size,
-             token_head, rot_offset, embed_dim, query_stride);
-      assert(false);
-    }
-    apply_rotary_embedding<scalar_t, true>(query + token_head, cos_ptr,
-                                           sin_ptr, rot_offset, embed_dim);
-  }
-
-  const int nk = num_kv_heads * embed_dim;
-  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
-    const int head_idx = i / embed_dim;
-    const int token_head = token_idx * key_stride + head_idx * head_size;
-    const int rot_offset = i % embed_dim;
-    apply_rotary_embedding<scalar_t, true>(key + token_head, cos_ptr,
-                                           sin_ptr, rot_offset, embed_dim);
-  }
-}
-
-#define scalar_t __nv_bfloat16
-
-extern "C" void rotary_embedding_bf16(
-  const int64_t* __restrict__ positions,       // [num_tokens]
-  scalar_t* __restrict__ query,                // [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ key,                  // [num_tokens, num_kv_heads, head_size]
-  const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim // 2]
-  const int32_t num_tokens,
-  const int32_t rot_dim,
-  const int32_t query_stride,
-  const int32_t key_stride,
-  const int32_t num_heads,
-  const int32_t num_kv_heads,
-  const int32_t head_size)
-{
-  dim3 grid(num_tokens);
-  dim3 block(std::min(num_heads * rot_dim / 2, 512));
-  const cudaStream_t stream = 0;  // Use the default stream.
-  // const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  rotary_embedding_kernel<<<grid, block, 0, stream>>>(
-    positions, query, key, cos_sin_cache, rot_dim, query_stride, key_stride, num_heads, num_kv_heads, head_size);
-}
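The rotation itself is the standard 2-D rotation applied to (x, y) pairs; in the NeoX-style path, element i is paired with element i + embed_dim and both are rotated by the angle cached for the token's position. A host-side reference in float for readability (illustrative, not the deleted code):

    // Rotates one head of dimension 2*embed_dim in place, NeoX pairing:
    //   x' = x*cos - y*sin,  y' = y*cos + x*sin
    // cos_cache/sin_cache point at the token's row of cos_sin_cache.
    void rotate_neox_ref(float* head, const float* cos_cache,
                         const float* sin_cache, int embed_dim) {
      for (int i = 0; i < embed_dim; i++) {
        const float x = head[i];
        const float y = head[i + embed_dim];
        head[i]             = x * cos_cache[i] - y * sin_cache[i];
        head[i + embed_dim] = y * cos_cache[i] + x * sin_cache[i];
      }
    }

The GPT-J branch differs only in pairing adjacent elements (2*i, 2*i + 1) while sharing each cos/sin value between the two.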