Skip to content

Commit a13dc4c

Browse files
authored
[ET-VK][ez] Align SDPA attention weights S dim to the next multiple of 4 (#15599)
Title says it all! Why? * Technically, this should not be needed, but the SDPA op was producing incorrect output on Samsung S24 with buffer input tensors. The exact root cause is unclear, but it appears to be an issue specific to the Adreno 750 since it does not reproduce on any other GPU. The best guess at the moment is that we need to ensure that there is no possibility of multiple threads writing to the same memory location. Differential Revision: [D86226134](https://our.internmc.facebook.com/intern/diff/D86226134/)
1 parent 67af512 commit a13dc4c

File tree

6 files changed

+18
-13
lines changed

6 files changed

+18
-13
lines changed

backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ void main() {
7676
const int Q_H = q_projected_sizes.y;
7777
// sequence length
7878
const int S = q_projected_sizes.z;
79+
const int S_aligned = align_up_4(S);
7980
// manually determine size of the context_len dim of the attention weight.
8081
// The "actual" tensor sizes may have been aligned to a multiple of 4 to allow
8182
// memory loads to be aligned to texel boundaries.
@@ -96,7 +97,7 @@ void main() {
9697
// number of threads in the work group.
9798
for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) {
9899
VEC4_T in_texel = load_attn_weights_c4(
99-
c4, s, q_h, context_texel_len, S, Q_H);
100+
c4, s, q_h, context_texel_len, S_aligned, Q_H);
100101

101102
for (int comp = 0; comp < 4; comp++) {
102103
local_exp_sum += exp(in_texel[comp]);
@@ -108,7 +109,7 @@ void main() {
108109
for (int c4 = C4_limit; c4 < context_texel_len; ++c4) {
109110
const int c_base = mul_4(c4);
110111
VEC4_T in_texel = load_attn_weights_c4(
111-
c4, s, q_h, context_texel_len, S, Q_H);
112+
c4, s, q_h, context_texel_len, S_aligned, Q_H);
112113

113114
[[unroll]] for (int comp = 0; comp < 4; comp++) {
114115
if (c_base + comp < context_len) {
@@ -138,19 +139,19 @@ void main() {
138139
// Now go back through each element in the row and normalize
139140
for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) {
140141
VEC4_T in_texel = load_attn_weights_c4(
141-
c4, s, q_h, context_texel_len, S, Q_H);
142+
c4, s, q_h, context_texel_len, S_aligned, Q_H);
142143

143144
VEC4_T out_texel = exp(in_texel) / local_exp_sum;
144145
store_attn_weights_softmax_c4(
145-
out_texel, c4, s, q_h, context_texel_len, S, Q_H);
146+
out_texel, c4, s, q_h, context_texel_len, S_aligned, Q_H);
146147
}
147148
// First thread in the work group responsible for handling last texel if it
148149
// contains any padded elements
149150
if (worker_id == 0) {
150151
for (int c4 = C4_limit; c4 < context_texel_len; ++c4) {
151152
const int c_base = mul_4(c4);
152153
VEC4_T in_texel = load_attn_weights_c4(
153-
c4, s, q_h, context_texel_len, S, Q_H);
154+
c4, s, q_h, context_texel_len, S_aligned, Q_H);
154155

155156
// Ensure that padding elements are set to 0.
156157
VEC4_T out_texel = VEC4_T(0);
@@ -160,7 +161,7 @@ void main() {
160161
}
161162
}
162163
store_attn_weights_softmax_c4(
163-
out_texel, c4, s, q_h, context_texel_len, S, Q_H);
164+
out_texel, c4, s, q_h, context_texel_len, S_aligned, Q_H);
164165
}
165166
}
166167
}

backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ void main() {
8181
const int Q_H = q_projected_sizes.y;
8282
// sequence length
8383
const int S = q_projected_sizes.z;
84+
const int S_aligned = align_up_4(S);
8485

8586
// number of K/V heads
8687
const int KV_H = k_cache_sizes.y;
@@ -205,7 +206,7 @@ void main() {
205206
s,
206207
q_h,
207208
context_texel_len,
208-
S,
209+
S_aligned,
209210
Q_H);
210211
}
211212
}

backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ void main() {
9393
const int Q_H = q_projected_sizes.y;
9494
// sequence length
9595
const int S = q_projected_sizes.z;
96+
const int S_aligned = align_up_4(S);
9697

9798
// number of K/V heads
9899
const int KV_H = k_cache_sizes.y;
@@ -196,6 +197,6 @@ void main() {
196197
s,
197198
q_h,
198199
context_texel_len,
199-
S,
200+
S_aligned,
200201
Q_H);
201202
}

backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ void main() {
8181
const int Q_H = q_projected_sizes.y;
8282
// sequence length
8383
const int S = q_projected_sizes.z;
84+
const int S_aligned = align_up_4(S);
8485

8586
// number of K/V heads
8687
const int KV_H = v_cache_sizes.y;
@@ -120,7 +121,7 @@ void main() {
120121
s,
121122
q_h,
122123
context_texel_len,
123-
S,
124+
S_aligned,
124125
Q_H);
125126

126127
load_v_cache_tile_no_checks(
@@ -146,7 +147,7 @@ void main() {
146147
s,
147148
q_h,
148149
context_texel_len,
149-
S,
150+
S_aligned,
150151
Q_H);
151152

152153
load_v_cache_tile_with_checks(

backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ void main() {
7575
const int Q_H = q_projected_sizes.y;
7676
// sequence length
7777
const int S = q_projected_sizes.z;
78+
const int S_aligned = align_up_4(S);
7879

7980
// number of K/V heads
8081
const int KV_H = v_cache_sizes.y;
@@ -113,7 +114,7 @@ void main() {
113114
s,
114115
q_h,
115116
context_texel_len,
116-
S,
117+
S_aligned,
117118
Q_H);
118119

119120
load_v_cache_tile_no_checks(
@@ -136,7 +137,7 @@ void main() {
136137
s,
137138
q_h,
138139
context_texel_len,
139-
S,
140+
S_aligned,
140141
Q_H);
141142

142143
load_v_cache_tile_with_checks(

backends/vulkan/runtime/graph/ops/impl/SDPA.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ void resize_compute_attn_weights_node(
5050
std::vector<int64_t> out_sizes = {
5151
1, // batch
5252
num_q_heads,
53-
seq_len,
53+
utils::align_up_4(seq_len),
5454
utils::align_up_4(context_len)};
5555

5656
graph->virtual_resize(attn_weights, out_sizes);

0 commit comments

Comments
 (0)