Skip to content

Commit

Permalink
Lora expand (#4)
Browse files (browse the repository at this point in the history)
* L2

Signed-off-by: Abatom <[email protected]>

* L2

Signed-off-by: Abatom <[email protected]>

---------

Signed-off-by: Abatom <[email protected]>
  • Loading branch information
Abatom authored Dec 20, 2024
1 parent 24e893c commit c9747c6
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions vllm/lora/ops/sgmv_expand.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -57,15 +57,16 @@ def _sgmv_expand_kernel(

M = tl.load(seq_lens + cur_batch)

num_pid_m = tl.cdiv(M, BLOCK_M)
num_pid_n = tl.cdiv(N, BLOCK_N)
GROUP_SIZE_M: tl.constexpr = 1
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
grid_m = tl.cdiv(M, BLOCK_M)
grid_n = tl.cdiv(N, BLOCK_N)
GROUP_M: tl.constexpr = 1
width = GROUP_M * grid_n
group_id = pid // width
first_pid_m = group_id * GROUP_M
group_idx = pid % width
group_size_m = min(grid_m - first_pid_m, GROUP_M)
pid_m = first_pid_m + (group_idx % group_size_m)
pid_n = group_idx // group_size_m

if pid_m * BLOCK_M > M:
return
Expand All @@ -79,6 +80,7 @@ def _sgmv_expand_kernel(
offset_k = tl.arange(0, BLOCK_K)
ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
rak = tl.max_contiguous(tl.multiple_of(offset_k % K, BLOCK_K), BLOCK_K)

if SAME_STRIDE:
cur_lora_d0_stride = ls_d0_ptr
Expand All @@ -99,9 +101,9 @@ def _sgmv_expand_kernel(

a_ptr = (cur_input_ptr + cur_seq_start * input_d1_stride +
ram[:, None] * input_d1_stride +
offset_k[None, :] * input_d2_stride, )
rak[None, :] * input_d2_stride, )
b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index +
offset_k[:, None] * cur_lora_d2_stride +
rak[:, None] * cur_lora_d2_stride +
rbn[None, :] * cur_lora_d1_stride)
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(tl.cdiv(K, BLOCK_K)):
Expand All @@ -110,10 +112,10 @@ def _sgmv_expand_kernel(
tiled_b = tl.load(b_ptr)
else:
tiled_a = tl.load(a_ptr,
mask=offset_k[None, :] < K - k * BLOCK_K,
mask=rak[None, :] < K - k * BLOCK_K,
other=0)
tiled_b = tl.load(b_ptr,
mask=offset_k[:, None] < K - k * BLOCK_K,
mask=rak[:, None] < K - k * BLOCK_K,
other=0)
if CAST_TYPE:
tiled_a = tiled_a.to(cur_lora_ptr.dtype.element_ty)
Expand Down

0 comments on commit c9747c6

Please sign in to comment.