Skip to content

Commit 9407312

Browse files
authored
[ET-VK][ez] Align SDPA attention weights S dim to the next multiple of 4
Differential Revision: D86226134 Pull Request resolved: #15578
1 parent 8010c98 commit 9407312

File tree

6 files changed

+18
-13
lines changed

6 files changed

+18
-13
lines changed

backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ void main() {
7676
const int Q_H = q_projected_sizes.y;
7777
// sequence length
7878
const int S = q_projected_sizes.z;
79+
const int S_aligned = align_up_4(S);
7980
// manually determine size of the context_len dim of the attention weight.
8081
// The "actual" tensor sizes may have been aligned to a multiple of 4 to allow
8182
// memory loads to be aligned to texel boundaries.
@@ -96,7 +97,7 @@ void main() {
9697
// number of threads in the work group.
9798
for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) {
9899
VEC4_T in_texel = load_attn_weights_c4(
99-
c4, s, q_h, context_texel_len, S, Q_H);
100+
c4, s, q_h, context_texel_len, S_aligned, Q_H);
100101

101102
for (int comp = 0; comp < 4; comp++) {
102103
local_exp_sum += exp(in_texel[comp]);
@@ -108,7 +109,7 @@ void main() {
108109
for (int c4 = C4_limit; c4 < context_texel_len; ++c4) {
109110
const int c_base = mul_4(c4);
110111
VEC4_T in_texel = load_attn_weights_c4(
111-
c4, s, q_h, context_texel_len, S, Q_H);
112+
c4, s, q_h, context_texel_len, S_aligned, Q_H);
112113

113114
[[unroll]] for (int comp = 0; comp < 4; comp++) {
114115
if (c_base + comp < context_len) {
@@ -138,19 +139,19 @@ void main() {
138139
// Now go back through each element in the row and normalize
139140
for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) {
140141
VEC4_T in_texel = load_attn_weights_c4(
141-
c4, s, q_h, context_texel_len, S, Q_H);
142+
c4, s, q_h, context_texel_len, S_aligned, Q_H);
142143

143144
VEC4_T out_texel = exp(in_texel) / local_exp_sum;
144145
store_attn_weights_softmax_c4(
145-
out_texel, c4, s, q_h, context_texel_len, S, Q_H);
146+
out_texel, c4, s, q_h, context_texel_len, S_aligned, Q_H);
146147
}
147148
// First thread in the work group responsible for handling last texel if it
148149
// contains any padded elements
149150
if (worker_id == 0) {
150151
for (int c4 = C4_limit; c4 < context_texel_len; ++c4) {
151152
const int c_base = mul_4(c4);
152153
VEC4_T in_texel = load_attn_weights_c4(
153-
c4, s, q_h, context_texel_len, S, Q_H);
154+
c4, s, q_h, context_texel_len, S_aligned, Q_H);
154155

155156
// Ensure that padding elements are set to 0.
156157
VEC4_T out_texel = VEC4_T(0);
@@ -160,7 +161,7 @@ void main() {
160161
}
161162
}
162163
store_attn_weights_softmax_c4(
163-
out_texel, c4, s, q_h, context_texel_len, S, Q_H);
164+
out_texel, c4, s, q_h, context_texel_len, S_aligned, Q_H);
164165
}
165166
}
166167
}

backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ void main() {
8181
const int Q_H = q_projected_sizes.y;
8282
// sequence length
8383
const int S = q_projected_sizes.z;
84+
const int S_aligned = align_up_4(S);
8485

8586
// number of K/V heads
8687
const int KV_H = k_cache_sizes.y;
@@ -205,7 +206,7 @@ void main() {
205206
s,
206207
q_h,
207208
context_texel_len,
208-
S,
209+
S_aligned,
209210
Q_H);
210211
}
211212
}

backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ void main() {
9393
const int Q_H = q_projected_sizes.y;
9494
// sequence length
9595
const int S = q_projected_sizes.z;
96+
const int S_aligned = align_up_4(S);
9697

9798
// number of K/V heads
9899
const int KV_H = k_cache_sizes.y;
@@ -196,6 +197,6 @@ void main() {
196197
s,
197198
q_h,
198199
context_texel_len,
199-
S,
200+
S_aligned,
200201
Q_H);
201202
}

backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ void main() {
8181
const int Q_H = q_projected_sizes.y;
8282
// sequence length
8383
const int S = q_projected_sizes.z;
84+
const int S_aligned = align_up_4(S);
8485

8586
// number of K/V heads
8687
const int KV_H = v_cache_sizes.y;
@@ -120,7 +121,7 @@ void main() {
120121
s,
121122
q_h,
122123
context_texel_len,
123-
S,
124+
S_aligned,
124125
Q_H);
125126

126127
load_v_cache_tile_no_checks(
@@ -146,7 +147,7 @@ void main() {
146147
s,
147148
q_h,
148149
context_texel_len,
149-
S,
150+
S_aligned,
150151
Q_H);
151152

152153
load_v_cache_tile_with_checks(

backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ void main() {
7575
const int Q_H = q_projected_sizes.y;
7676
// sequence length
7777
const int S = q_projected_sizes.z;
78+
const int S_aligned = align_up_4(S);
7879

7980
// number of K/V heads
8081
const int KV_H = v_cache_sizes.y;
@@ -113,7 +114,7 @@ void main() {
113114
s,
114115
q_h,
115116
context_texel_len,
116-
S,
117+
S_aligned,
117118
Q_H);
118119

119120
load_v_cache_tile_no_checks(
@@ -136,7 +137,7 @@ void main() {
136137
s,
137138
q_h,
138139
context_texel_len,
139-
S,
140+
S_aligned,
140141
Q_H);
141142

142143
load_v_cache_tile_with_checks(

backends/vulkan/runtime/graph/ops/impl/SDPA.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ void resize_compute_attn_weights_node(
5050
std::vector<int64_t> out_sizes = {
5151
1, // batch
5252
num_q_heads,
53-
seq_len,
53+
utils::align_up_4(seq_len),
5454
utils::align_up_4(context_len)};
5555

5656
graph->virtual_resize(attn_weights, out_sizes);

0 commit comments

Comments
 (0)