
Commit 03475fe

[GPU] sdpa_micro: modify attn mask tile loading
1 parent 762eb8a commit 03475fe

File tree

1 file changed: +15 -1 lines
  • src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_micro.cl

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_micro.cl

Lines changed: 15 additions & 1 deletion
@@ -404,7 +404,21 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG
 #if WITH_ATTN_MASK
     /* Load mask. No remainder handling needed assuming k block size is a power of 2. */
     mask_tile_type mask_tile;
-    tile_load_t(&mask_tile, msk, q, k, sg_j0_kq + wg_j0, k0 + sg_i0_kq);
+
+    // Check if attention mask has a single Query dimension (e.g., [batch, num_heads, 1, sequence_length])
+    if (MSK_D2 == 1) {
+        // Define mask dimensions for single Query dimension
+        uint mask_m = MSK_D1; // num_heads
+        uint mask_n = MSK_D3; // sequence_length
+
+        tile_load_t(&mask_tile, msk, mask_m, mask_n, 0, k0 + sg_i0_kq);
+    } else {
+        // General case: attention mask matches Q*K^T shape
+        uint mask_m = q; // Q sequence length
+        uint mask_n = k; // K sequence length
+
+        tile_load_t(&mask_tile, msk, mask_m, mask_n, sg_j0_kq + wg_j0, k0 + sg_i0_kq);
+    }
 #endif

 #if REMAINDER_K

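Note on the change: the new branch distinguishes two attention-mask layouts, a [batch, num_heads, 1, sequence_length] mask that is broadcast across all query rows (the MSK_D2 == 1 case, loaded with a row offset of 0) and a full mask matching the Q*K^T shape (loaded with the usual query-row offset). The plain-C sketch below illustrates only that broadcast rule; the mask_at helper, its parameters, and the row-major layout are illustrative assumptions and not code from the kernel or from OpenVINO.

#include <stdio.h>
#include <stddef.h>

/* Hypothetical helper showing the broadcast rule: when the mask's query
 * dimension (mask_q) is 1, the query-row index is dropped and the same mask
 * row is reused for every query position, mirroring the MSK_D2 == 1 branch
 * in the diff above. A row-major [batch, heads, mask_q, k_len] layout is
 * assumed here for illustration. */
static float mask_at(const float *mask, size_t heads, size_t mask_q, size_t k_len,
                     size_t b, size_t h, size_t q_row, size_t k_col) {
    size_t q_idx = (mask_q == 1) ? 0 : q_row;   /* broadcast over query rows */
    return mask[((b * heads + h) * mask_q + q_idx) * k_len + k_col];
}

int main(void) {
    /* Toy mask of shape [1, 1, 1, 4]: one row of per-key additive biases. */
    float mask[4] = {0.f, 0.f, -1e9f, -1e9f};

    /* Every query row reads the same four mask values. */
    for (size_t q = 0; q < 3; ++q) {
        for (size_t k = 0; k < 4; ++k)
            printf("%g ", mask_at(mask, 1, 1, 4, 0, 0, q, k));
        printf("\n");
    }
    return 0;
}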