Skip to content

Commit

Permalink
[Kernel] Explicitly specify other value in tl.load calls (vllm-projec…
Browse files Browse the repository at this point in the history
…t#9014)

Signed-off-by: Angus Wang <[email protected]>
  • Loading branch information
angusYuhao authored Nov 18, 2024
1 parent d782f25 commit 99a49c9
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -157,19 +157,22 @@ def _fwd_kernel_inner(
k = tl.load(
k_ptrs + start_n * stride_kt,
mask=offs_n[None, :] + start_n < k_seqlen,
other=0.0,
)
else:
k = tl.load(
k_ptrs + start_n * stride_kt,
mask=(offs_n[None, :] + start_n < k_seqlen) &
(offs_d[:, None] < D_HEAD),
other=0.0,
)
else:
if EVEN_D:
k = tl.load(k_ptrs + start_n * stride_kt)
else:
k = tl.load(k_ptrs + start_n * stride_kt,
mask=offs_d[:, None] < D_HEAD)
mask=offs_d[:, None] < D_HEAD,
other=0.0)

qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
Expand Down Expand Up @@ -200,19 +203,22 @@ def _fwd_kernel_inner(
v = tl.load(
v_ptrs + start_n * stride_vt,
mask=offs_n[:, None] + start_n < k_seqlen,
other=0.0,
)
else:
v = tl.load(
v_ptrs + start_n * stride_vt,
mask=(offs_n[:, None] + start_n < k_seqlen) &
(offs_d[None, :] < D_HEAD),
other=0.0,
)
else:
if EVEN_D:
v = tl.load(v_ptrs + start_n * stride_vt)
else:
v = tl.load(v_ptrs + start_n * stride_vt,
mask=offs_d[None, :] < D_HEAD)
mask=offs_d[None, :] < D_HEAD,
other=0.0)

acc += tl.dot(p, v)

Expand Down Expand Up @@ -318,12 +324,13 @@ def _fwd_kernel_batch_inference(
q = tl.load(
Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
mask=offs_m[:, None] < q_seqlen,
other=0.0,
)
else:
q = tl.load(
Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
other=0,
other=0.0,
)

sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h +
Expand Down
4 changes: 3 additions & 1 deletion vllm/lora/ops/bgmv_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ def _bgmv_expand_kernel(
other=0.0,
) # [BLOCK_N,BLOCK_K]
if ADD_INPUTS:
tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask)
tiled_out = tl.load(c_ptr + current_n * cn_stride,
mask=c_mask,
other=0.0)
accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out
else:
accumulator = tl.sum(tiled_a * tiled_b, 1)
Expand Down
8 changes: 7 additions & 1 deletion vllm/lora/ops/bgmv_expand_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,13 @@ def _bgmv_expand_slice_kernel(
) # [BLOCK_N,BLOCK_K]

if ADD_INPUTS:
tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask)
# explicitly pass in other=None to tell triton that masked values
# can be uninitialized. This is OK because the later tl.store
# operation uses the same mask, eliminating the risk of garbage
# values propagating
tiled_out = tl.load(c_ptr + current_n * cn_stride,
mask=c_mask,
other=None)
accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out
else:
accumulator = tl.sum(tiled_a * tiled_b, 1)
Expand Down
5 changes: 4 additions & 1 deletion vllm/lora/ops/sgmv_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,10 @@ def _sgmv_expand_kernel(
c_mask = (offset_cm[:, None] <
(cur_seq_start + M)) & (offset_cn[None, :] < N)
if ADD_INPUTS:
tiled_out = tl.load(c_ptr, mask=c_mask)
# explicitly pass in other=None to tell triton that masked values
# can be uninitialized. This is OK because the later tl.store operation
# uses the same mask, eliminating the risk of garbage values propagating
tiled_out = tl.load(c_ptr, mask=c_mask, other=None)
tiled_c += tiled_out
tl.store(c_ptr, tiled_c, mask=c_mask)

Expand Down
5 changes: 4 additions & 1 deletion vllm/lora/ops/sgmv_expand_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@ def _sgmv_expand_slice_kernel(
c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] <
(slice_offset + N))
if ADD_INPUTS:
tiled_out = tl.load(c_ptr, mask=c_mask)
# explicitly pass in other=None to tell triton that masked values
# can be uninitialized. This is OK because the later tl.store operation
# uses the same mask, eliminating the risk of garbage values propagating
tiled_out = tl.load(c_ptr, mask=c_mask, other=None)
tiled_c += tiled_out
tl.store(c_ptr, tiled_c, mask=c_mask)

Expand Down
14 changes: 7 additions & 7 deletions vllm/model_executor/layers/quantization/awq_triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def awq_dequantize_kernel(
result_masks = result_masks_y[:, None] & result_masks_x[None, :]

# Load the weights.
iweights = tl.load(qweight_ptr + offsets, masks)
iweights = tl.load(qweight_ptr + offsets, masks, 0.0)
iweights = tl.interleave(iweights, iweights)
iweights = tl.interleave(iweights, iweights)
iweights = tl.interleave(iweights, iweights)
Expand Down Expand Up @@ -71,7 +71,7 @@ def awq_dequantize_kernel(
zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :]

# Load the zeros.
zeros = tl.load(zeros_ptr + zero_offsets, zero_masks)
zeros = tl.load(zeros_ptr + zero_offsets, zero_masks, 0.0)
zeros = tl.interleave(zeros, zeros)
zeros = tl.interleave(zeros, zeros)
zeros = tl.interleave(zeros, zeros)
Expand All @@ -91,7 +91,7 @@ def awq_dequantize_kernel(
scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :]

# Load the scales.
scales = tl.load(scales_ptr + scale_offsets, scale_masks)
scales = tl.load(scales_ptr + scale_offsets, scale_masks, 0.0)
scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))

# Dequantize.
Expand Down Expand Up @@ -165,10 +165,10 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):
masks_k = offsets_k < K
masks_a = masks_am[:, None] & masks_k[None, :]
a = tl.load(a_ptrs, mask=masks_a)
a = tl.load(a_ptrs, mask=masks_a, other=0.0)

masks_b = masks_k[:, None] & masks_bn[None, :]
b = tl.load(b_ptrs, mask=masks_b)
b = tl.load(b_ptrs, mask=masks_b, other=0.0)
b = tl.interleave(b, b)
b = tl.interleave(b, b)
b = tl.interleave(b, b)
Expand All @@ -181,7 +181,7 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
masks_zk = offsets_szk < K // group_size
masks_z = masks_zk[:, None] & masks_zn[None, :]
zeros_ptrs = zeros_ptr + offsets_z
zeros = tl.load(zeros_ptrs, mask=masks_z)
zeros = tl.load(zeros_ptrs, mask=masks_z, other=0.0)
zeros = tl.interleave(zeros, zeros)
zeros = tl.interleave(zeros, zeros)
zeros = tl.interleave(zeros, zeros)
Expand All @@ -191,7 +191,7 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
masks_sk = offsets_szk < K // group_size
masks_s = masks_sk[:, None] & masks_sn[None, :]
scales_ptrs = scales_ptr + offsets_s
scales = tl.load(scales_ptrs, mask=masks_s)
scales = tl.load(scales_ptrs, mask=masks_s, other=0.0)
scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N))

b = (b >> shifts) & 0xF
Expand Down

0 comments on commit 99a49c9

Please sign in to comment.