diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index c1f100c541e38..36dbbc54d62f2 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -57,15 +57,16 @@ def _sgmv_expand_kernel( M = tl.load(seq_lens + cur_batch) - num_pid_m = tl.cdiv(M, BLOCK_M) - num_pid_n = tl.cdiv(N, BLOCK_N) - GROUP_SIZE_M: tl.constexpr = 1 - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + GROUP_M: tl.constexpr = 1 + width = GROUP_M * grid_n + group_id = pid // width + first_pid_m = group_id * GROUP_M + group_idx = pid % width + group_size_m = min(grid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (group_idx % group_size_m) + pid_n = group_idx // group_size_m if pid_m * BLOCK_M > M: return @@ -79,6 +80,7 @@ def _sgmv_expand_kernel( offset_k = tl.arange(0, BLOCK_K) ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + rak = tl.max_contiguous(tl.multiple_of(offset_k % K, BLOCK_K), BLOCK_K) if SAME_STRIDE: cur_lora_d0_stride = ls_d0_ptr @@ -99,9 +101,9 @@ def _sgmv_expand_kernel( a_ptr = (cur_input_ptr + cur_seq_start * input_d1_stride + ram[:, None] * input_d1_stride + - offset_k[None, :] * input_d2_stride, ) + rak[None, :] * input_d2_stride, ) b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index + - offset_k[:, None] * cur_lora_d2_stride + + rak[:, None] * cur_lora_d2_stride + rbn[None, :] * cur_lora_d1_stride) accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k in range(tl.cdiv(K, BLOCK_K)): @@ -110,10 +112,10 @@ def _sgmv_expand_kernel( tiled_b = tl.load(b_ptr) else: tiled_a = tl.load(a_ptr, - mask=offset_k[None, :] < K - k * BLOCK_K, + mask=rak[None, :] < K - k * BLOCK_K, other=0) tiled_b = tl.load(b_ptr, - mask=offset_k[:, None] < K - k * BLOCK_K, + mask=rak[:, None] < K - k * BLOCK_K, other=0) if CAST_TYPE: tiled_a = tiled_a.to(cur_lora_ptr.dtype.element_ty)