Skip to content

Commit

Permalink
[BugFix] Fix RoPE kernel on long sequences(vllm-project#2164)
Browse files Browse the repository at this point in the history
  • Loading branch information
WoosukKwon authored Dec 18, 2023
1 parent 8041b73 commit 76a7983
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions csrc/pos_encoding_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ __global__ void rotary_embedding_kernel(
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
const int rot_dim,
const int query_stride,
const int key_stride,
const int64_t query_stride,
const int64_t key_stride,
const int num_heads,
const int num_kv_heads,
const int head_size) {
Expand All @@ -60,7 +60,7 @@ __global__ void rotary_embedding_kernel(
const int nq = num_heads * embed_dim;
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int token_head = token_idx * query_stride + head_idx * head_size;
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
sin_ptr, rot_offset, embed_dim);
Expand All @@ -69,7 +69,7 @@ __global__ void rotary_embedding_kernel(
const int nk = num_kv_heads * embed_dim;
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int token_head = token_idx * key_stride + head_idx * head_size;
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
sin_ptr, rot_offset, embed_dim);
Expand All @@ -89,8 +89,8 @@ void rotary_embedding(
int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size;
int num_kv_heads = key.size(-1) / head_size;
int query_stride = query.stride(-2);
int key_stride = key.stride(-2);
int64_t query_stride = query.stride(-2);
int64_t key_stride = key.stride(-2);

dim3 grid(num_tokens);
dim3 block(std::min(num_heads * rot_dim / 2, 512));
Expand Down

0 comments on commit 76a7983

Please sign in to comment.